Skip to content

Commit 7edd46b

Browse files
authored
Add files via upload
1 parent 7ea623f commit 7edd46b

File tree

1 file changed

+81
-0
lines changed

1 file changed

+81
-0
lines changed

prepare_dataset.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
"""
2+
# Function: Read files and preprocess them
3+
# Steps:
4+
# 1. Import data from the gdf file provided before the competition,
5+
# remove unwanted channels, and select required events.
6+
# 2. Select desired time segments for slicing; treat each segment (4s) as one sample.
7+
# 3. Import labels from the mat file provided after the competition,
8+
# ensuring they correspond with epochs and their numbers match.
9+
# 4. Save the resulting data in a new mat file,
10+
# preparing it for use in the subsequent main.py.
11+
"""
12+
13+
import mne
14+
import numpy as np
15+
import scipy.signal as signal
16+
from scipy.io import savemat
17+
import scipy.io as sio
18+
import numpy as np
19+
20+
def changeGdf2Mat(dir_path, mode="train"):
21+
'''
22+
read data from GDF files and store as mat files
23+
24+
Parameters
25+
----------
26+
dir_path : str
27+
GDF file dir path.
28+
mode : str, optional
29+
change train dataset or eval dataset. The default is "train".
30+
31+
Returns
32+
-------
33+
None.
34+
35+
'''
36+
mode_str = ''
37+
if mode=="train":
38+
mode_str = 'T'
39+
else:
40+
mode_str = 'E'
41+
for nSub in range(1, 10):
42+
# Load the gdf file
43+
data_filename = dir_path+'BCICIV_2a_gdf/A0{}{}.gdf'.format(nSub, mode_str)
44+
raw = mne.io.read_raw_gdf(data_filename)
45+
46+
# Select the events of interest
47+
events, event_dict = mne.events_from_annotations(raw)
48+
if mode=="train":
49+
# train dataset are labeled
50+
event_id = {'Left': event_dict['769'],
51+
'Right': event_dict['770'],
52+
'Foot': event_dict['771'],
53+
'Tongue': event_dict['772']}
54+
else:
55+
# evaluate dataset are labeled as 'Unknnow'
56+
event_id = {'Unknown': event_dict['783']}
57+
58+
# 选取我们关心的四个类别对应的事件,这里events[:, 2]是指events中的第三列,即事件的编号。
59+
selected_events = events[np.isin(events[:, 2], list(event_id.values()))]
60+
61+
# remove EOG channels
62+
raw.info['bads'] += ['EOG-left', 'EOG-central', 'EOG-right']
63+
picks = mne.pick_types(raw.info, meg=False, eeg=True, eog=False, stim=False, exclude='bads')
64+
# Epoch the data
65+
# using 4s (1000 sample point ) segmentation
66+
epochs = mne.Epochs(raw, selected_events, event_id, picks=picks,tmin=0, tmax=3.996, preload=True, baseline=None)
67+
68+
filtered_data = epochs.get_data()
69+
label_filename = dir_path + 'true_labels/'+'A0{}{}.mat'.format(nSub, mode_str)
70+
mat = sio.loadmat(label_filename) # load target mat file
71+
labels = mat['classlabel']
72+
73+
# Save the data and labels to a .mat file
74+
result_filename = 'mymat_raw/A0{}{}.mat'.format(nSub, mode_str)
75+
savemat(result_filename, {'data': filtered_data, 'label': labels})
76+
77+
dir_path = './'
78+
# prepare train dataset
79+
changeGdf2Mat(dir_path, 'train')
80+
# prepare test dataset
81+
changeGdf2Mat(dir_path, 'eval')

0 commit comments

Comments
 (0)