diff --git a/Training/DataCreateUpdated.py b/Training/DataCreateUpdated.py new file mode 100644 index 0000000..48de1e2 --- /dev/null +++ b/Training/DataCreateUpdated.py @@ -0,0 +1,155 @@ +import numpy as np +import matplotlib.pyplot as plt +import os +from scipy.fft import fft +from string import digits +from datetime import datetime, timedelta +import requests + +from sklearn.preprocessing import LabelEncoder +from sklearn.preprocessing import OneHotEncoder + +# define example +labels = {'NONE' : 0, 'L_EYE' : 1, 'R_EYE' : 2, 'JAW_CLENCH' : 3, 'BROW_UP' : 4, 'BROW_DOWN': 5} +labelInts = np.array([0, 1, 2, 3, 4, 5]) +# integer encode +label_encoder = LabelEncoder() +integer_encoded = label_encoder.fit_transform(labelInts) +# binary encode +onehot_encoder = OneHotEncoder(sparse=False) +integer_encoded = integer_encoded.reshape(len(integer_encoded), 1) +onehot_encoded = onehot_encoder.fit_transform(integer_encoded) + +def getOneHot(label): + return onehot_encoded[labels[label]] + +def getOneHotLabels(): + output = {} + for label in labels: + output[label] = getOneHot(label) + return output + +def getTitle(recordingFile): + return recordingFile.split("-")[-1].split(".")[0].translate({ord(k): None for k in digits}) + +# pass none if dont want granularity +# pass none to dataLimit if want all the data +def getData(path, granularity, channels, dataLimit): + dataRaw = [] + dataStartLine = 6 + count = 0 + with open(path, 'r') as data_file: + for line in data_file: + if count >= dataStartLine: + dataRaw.append(line.strip().split(',')) + else: + count += 1 + dataRaw = np.char.strip(np.array(dataRaw)) + + dataChannels = dataRaw[:, 1:5] + timeChannels = dataRaw[:, 15] + + if granularity is None: + granularity = 1 + # the current channel of data + if dataLimit is None: + dataLimit = len(dataChannels) + + channelData = dataChannels[:,channels][:dataLimit:granularity].transpose() + y_channels = channelData.astype(float) + inds = np.arange(channelData.shape[1]) + t = np.array([datetime.strptime(time[11:],'%H:%M:%S.%f') for time in timeChannels]) + return y_channels,inds,t + +def getLabel(path): + dataRaw = [] + first = True + basetime = None + with open(path, 'r') as label_file: + for line in label_file: + if(first): + dt_obj = datetime.strptime(line[11:].strip(),'%H:%M:%S.%f') + basetime = dt_obj + # dataRaw.append(dt_obj) + first = False + else: + dr = np.char.strip(np.array(line[1:-1].split(", "))) + # print(dr) + for i in range(len(dr)): + if dr[i] == '1': + dataRaw.append(basetime + timedelta(seconds=i)) + # print(dataRaw) + return dataRaw + + +def groupbyInterval(data, labels, interval, actionType): + #data tuple (x,y,z). labels: datetimes. interval(ms): int + y_channels,inds,t = data + interval_ms = timedelta(milliseconds=interval) + + split_inds = [] + cutoff_times = [t[0]+interval_ms] + for ind in range(t.shape[0]): + time = t[ind] + if time >= cutoff_times[-1]: + split_inds.append(ind) + cutoff_times.append(cutoff_times[-1] + interval_ms) + + ind_groups = np.split(inds, split_inds) + y_channels_groups = np.split(y_channels, split_inds, axis=1) + t_groups = np.split(t, split_inds) + + #find min group size + min_group_size = ind_groups[0].shape[0] + for i in range(len(split_inds)-1): + if ind_groups[i].shape[0] < min_group_size: + min_group_size = ind_groups[i].shape[0] + + #rectangularize jagged arrays + for i in range(len(split_inds)): + ind_groups[i] = ind_groups[i][:min_group_size] + y_channels_groups[i] = y_channels_groups[i][:,:min_group_size] + t_groups[i] = t_groups[i][:min_group_size] + + #drop short last group + ind_groups = np.array(ind_groups[:-1]) + y_channels_groups = np.array(y_channels_groups[:-1]).transpose((1,0,2)) + t_groups = np.array(t_groups[:-1]) + + #assign labels to groups + NO_ACTION = 0 + ACTION = 1 + + if actionType: + NO_ACTION = getOneHot("NONE") + ACTION = getOneHot(actionType) + + l_groups = np.array([NO_ACTION] * ind_groups.shape[0]) + + lnum=0 + for ind in range(len(cutoff_times)): + if lnum==len(labels): + break + + cutoff_time = cutoff_times[ind] + if labels[lnum] < cutoff_time: + l_groups[ind] = ACTION + lnum+=1 + + return y_channels_groups, ind_groups, t_groups, l_groups + + +#THIS IS THE MAIN METHOD FOR INTERACTION +# Inputs: +# datafile, labelfile, interval, channels requested +def getObservations(dataPath, labelPath, interval, channels, actionType): + data = getData(dataPath, None, channels, None) + action_times = getLabel(labelPath) + observations = groupbyInterval(data, action_times, interval, actionType) + + return observations + + + + + diff --git a/calibration.ipynb b/calibration.ipynb new file mode 100644 index 0000000..c10eff0 --- /dev/null +++ b/calibration.ipynb @@ -0,0 +1,200 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "#Import libraries and function loading data\n", + "\n", + "from Training import DataCreateUpdated as dc\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import sklearn\n", + "from sklearn.neighbors import KNeighborsClassifier\n", + "from sklearn.model_selection import train_test_split\n", + "import random\n", + "\n", + "# rec_paths = [\"OpenBCI-RAW-2021-12-02_19-19-53.txt\", \n", + "# \"OpenBCI-RAW-2021-12-02_19-31-12.txt\", \n", + "# \"Recordings/Spring_2022/OpenBCISession_2022-02-16_Evan_JawClench_2/OpenBCI-RAW-2022-02-16_19-30-28.txt\",\n", + "# \"Recordings/Spring_2022/OpenBCISession_2022-02-16_Evan_LeftBlink_1/OpenBCI-RAW-2022-02-16_19-33-20.txt\",\n", + "# \"Recordings/Spring_2022/OpenBCISession_2022-02-16_Evan_LeftBlink_2/OpenBCI-RAW-2022-02-16_19-39-23.txt\",\n", + "# \"Recordings/Spring_2022/OpenBCISession_2022-02-16_Evan_RightBlink_1/OpenBCI-RAW-2022-02-16_19-42-03.txt\"]\n", + "# label_paths = [\"JawClench_labels_Ansh_12-02-21-1918.txt\", \n", + "# \"JawClench_labels_Ansh_12-02-21-1930.txt\", \n", + "# \"Recordings/Spring_2022/Evan_JawClench_1.txt\",\n", + "# \"Recordings/Spring_2022/Evan_LeftBlink_1.txt\",\n", + "# \"Recordings/Spring_2022/Evan_LeftBlink_2.txt\",\n", + "# \"Recordings/Spring_2022/Evan_RightBlink_1.txt\"]\n", + "# label_types = [\"JAW_CLENCH\", \n", + "# \"JAW_CLENCH\", \n", + "# \"JAW_CLENCH\",\n", + "# \"L_EYE\",\n", + "# \"L_EYE\",\n", + "# \"R_EYE\"]\n", + "\n", + "\n", + "rec_paths = [\"Recordings/Spring_2022/OpenBCISession_2022-03-23_Sam_BrowLower_2/OpenBCI-RAW-2022-03-23_20-43-17.txt\",\n", + " \"Recordings/Spring_2022/OpenBCISession_2022-03-23_Sam_BrowLower_1/OpenBCI-RAW-2022-03-23_20-41-22.txt\",\n", + " \"Recordings/Spring_2022/OpenBCISession_2022-03-23_Sam_BrowRaise_1/OpenBCI-RAW-2022-03-23_20-37-06.txt\",\n", + " \"Recordings/Spring_2022/OpenBCISession_2022-03-23_Sam_BrowRaise_2/OpenBCI-RAW-2022-03-23_20-39-04.txt\",\n", + " \"Recordings/Spring_2022/OpenBCISession_2022-03-23_Sam_JawClench_1/OpenBCI-RAW-2022-03-23_20-33-09.txt\",\n", + " \"Recordings/Spring_2022/OpenBCISession_2022-03-23_Sam_JawClench_2/OpenBCI-RAW-2022-03-23_20-34-58.txt\",\n", + " \"Recordings/Spring_2022/OpenBCISession_2022-03-23_Sam_LeftEye_1/OpenBCI-RAW-2022-03-23_20-21-37.txt\",\n", + " \"Recordings/Spring_2022/OpenBCISession_2022-03-23_Sam_LeftEye_2/OpenBCI-RAW-2022-03-23_20-25-30.txt\",\n", + " \"Recordings/Spring_2022/OpenBCISession_2022-03-23_Sam_RightEye_1/OpenBCI-RAW-2022-03-23_20-27-22.txt\",\n", + " \"Recordings/Spring_2022/OpenBCISession_2022-03-23_Sam_RightEye_2/OpenBCI-RAW-2022-03-23_20-29-16.txt\",\n", + " \"Recordings/Spring_2022/OpenBCISession_2022-03-26_Sam_BrowLower_1/OpenBCI-RAW-2022-03-26_18-33-15.txt\",\n", + " \"Recordings/Spring_2022/OpenBCISession_2022-03-26_Sam_BrowLower_2/OpenBCI-RAW-2022-03-26_18-35-00.txt\",\n", + " # \"Recordings/Spring_2022/OpenBCISession_2022-03-26_Sam_BrowRaise_1/OpenBCI-RAW-2022-03-26_18-38-07.txt\",\n", + " \"Recordings/Spring_2022/OpenBCISession_2022-03-26_Sam_BrowRaise_2/OpenBCI-RAW-2022-03-26_18-39-58.txt\",\n", + " \"Recordings/Spring_2022/OpenBCISession_2022-03-26_Sam_JawClench_1/OpenBCI-RAW-2022-03-26_18-21-33.txt\",\n", + " \"Recordings/Spring_2022/OpenBCISession_2022-03-26_Sam_JawClench_2/OpenBCI-RAW-2022-03-26_18-23-24.txt\",\n", + " \"Recordings/Spring_2022/OpenBCISession_2022-03-26_Sam_LeftBlink_1/OpenBCI-RAW-2022-03-26_18-25-33.txt\",\n", + " \"Recordings/Spring_2022/OpenBCISession_2022-03-26_Sam_LeftBlink_2/OpenBCI-RAW-2022-03-26_18-30-33.txt\",\n", + " \"Recordings/Spring_2022/OpenBCISession_2022-03-26_Sam_RightBlink_1/OpenBCI-RAW-2022-03-26_18-04-32.txt\",\n", + " \"Recordings/Spring_2022/OpenBCISession_2022-03-26_Sam_RightBlink_2/OpenBCI-RAW-2022-03-26_18-07-12.txt\"]\n", + "\n", + "label_paths = [\"Recordings/Spring_2022/OpenBCISession_2022-03-23_Sam_BrowLower_2/labels.txt\",\n", + " \"Recordings/Spring_2022/OpenBCISession_2022-03-23_Sam_BrowLower_1/labels.txt\",\n", + " \"Recordings/Spring_2022/OpenBCISession_2022-03-23_Sam_BrowRaise_1/labels.txt\",\n", + " \"Recordings/Spring_2022/OpenBCISession_2022-03-23_Sam_BrowRaise_2/labels.txt\",\n", + " \"Recordings/Spring_2022/OpenBCISession_2022-03-23_Sam_JawClench_1/labels.txt\",\n", + " \"Recordings/Spring_2022/OpenBCISession_2022-03-23_Sam_JawClench_2/labels.txt\",\n", + " \"Recordings/Spring_2022/OpenBCISession_2022-03-23_Sam_LeftEye_1/labels.txt\",\n", + " \"Recordings/Spring_2022/OpenBCISession_2022-03-23_Sam_LeftEye_2/labels.txt\",\n", + " \"Recordings/Spring_2022/OpenBCISession_2022-03-23_Sam_RightEye_1/labels.txt\",\n", + " \"Recordings/Spring_2022/OpenBCISession_2022-03-23_Sam_RightEye_2/labels.txt\",\n", + " \"Recordings/Spring_2022/OpenBCISession_2022-03-26_Sam_BrowLower_1/labels.txt\",\n", + " \"Recordings/Spring_2022/OpenBCISession_2022-03-26_Sam_BrowLower_2/labels.txt\",\n", + " # \"Recordings/Spring_2022/OpenBCISession_2022-03-26_Sam_BrowRaise_1/labels.txt\",\n", + " \"Recordings/Spring_2022/OpenBCISession_2022-03-26_Sam_BrowRaise_2/labels.txt\",\n", + " \"Recordings/Spring_2022/OpenBCISession_2022-03-26_Sam_JawClench_1/labels.txt\",\n", + " \"Recordings/Spring_2022/OpenBCISession_2022-03-26_Sam_JawClench_2/labels.txt\",\n", + " \"Recordings/Spring_2022/OpenBCISession_2022-03-26_Sam_LeftBlink_1/labels.txt\",\n", + " \"Recordings/Spring_2022/OpenBCISession_2022-03-26_Sam_LeftBlink_2/labels.txt\",\n", + " \"Recordings/Spring_2022/OpenBCISession_2022-03-26_Sam_RightBlink_1/labels.txt\",\n", + " \"Recordings/Spring_2022/OpenBCISession_2022-03-26_Sam_RightBlink_2/labels.txt\"]\n", + "\n", + "label_types = [\"BROW_DOWN\", \"BROW_DOWN\", \"BROW_UP\", \"BROW_UP\", \"JAW_CLENCH\", \"JAW_CLENCH\", \"L_EYE\", \"L_EYE\", \"R_EYE\", \"R_EYE\", \"BROW_DOWN\", \"BROW_DOWN\", \"BROW_UP\", \"JAW_CLENCH\", \"JAW_CLENCH\", \"L_EYE\", \"L_EYE\", \"R_EYE\", \"R_EYE\"]\n", + "\n", + "#Chooses which input data to use\n", + "inputsToUse = np.arange(len(rec_paths))" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def getMeanAbsDeviation(dchannel):\n", + " return np.transpose(np.mean(np.abs(dchannel - np.mean(dchannel, axis=2, keepdims=True)), axis=2))\n", + "\n", + "def getMeanSquaredDeviation(dchannel):\n", + " return np.transpose(np.mean(np.square(dchannel - np.mean(dchannel, axis=2, keepdims=True)), axis=2))\n", + "\n", + "def getMean(dchannel):\n", + " return np.transpose(np.mean(dchannel, axis=2))\n", + "\n", + "def getPercentile(dchannel, percent):\n", + " return np.transpose(np.percentile(dchannel, percent, axis=2))\n", + "\n", + "def getPercentile10(dchannel, percent=10):\n", + " return np.transpose(np.percentile(dchannel, percent, axis=2))\n", + "\n", + "def getPercentile90(dchannel, percent=90):\n", + " return np.transpose(np.percentile(dchannel, percent, axis=2))\n", + "\n", + "def getPercentile15(dchannel, percent=15):\n", + " return np.transpose(np.percentile(dchannel, percent, axis=2))\n", + "\n", + "def getPercentile85(dchannel, percent=85):\n", + " return np.transpose(np.percentile(dchannel, percent, axis=2))\n", + "\n", + "def getPercentile5(dchannel, percent=5):\n", + " return np.transpose(np.percentile(dchannel, percent, axis=2))\n", + "\n", + "def getPercentile95(dchannel, percent=95):\n", + " return np.transpose(np.percentile(dchannel, percent, axis=2))\n", + "\n", + "\n", + "def getSpread(dchannel):\n", + " return np.transpose(np.max(dchannel, axis=2) - np.min(dchannel, axis=2))\n", + "\n", + "def getSpreadPercentile(dchannel, low=5, high=95):\n", + " return getPercentile(dchannel, high) - getPercentile(dchannel, low)\n", + "\n", + "def getPeakCount(dchannel, w=3):\n", + " ret = np.zeros((dchannel.shape[1], dchannel.shape[0]))\n", + " for ch in range(dchannel.shape[0]):\n", + " for sample in range(dchannel.shape[1]):\n", + " count = 0\n", + " for tind in range(w, dchannel.shape[2]-w):\n", + " isPeak = True\n", + " for x in range(1, w+1):\n", + " isPeak &= (dchannel[ch, sample, tind] > dchannel[ch, sample, tind-x] and dchannel[ch, sample, tind] > dchannel[ch, sample, tind+x])\n", + " if(isPeak):\n", + " count+=1\n", + " ret[sample, ch] = count\n", + " return ret\n", + "\n", + "# Loads in X and Y\n", + "functsList = [getMean, getMeanSquaredDeviation, getMeanAbsDeviation, getSpreadPercentile, getPeakCount, getPercentile10, getPercentile90]\n", + "\n", + "for featureFunc in functsList:\n", + " y = None\n", + " fX = None\n", + " for i in inputsToUse:\n", + " obs = dc.getObservations(rec_paths[i], label_paths[i], 1000, [0,1,2,3], label_types[i])\n", + " y_channels_groups = obs" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "#Calibration Testing for Individual EEG Channels\n", + "import numpy as np\n", + "\n", + "def Calibrate(data):\n", + " # Find the mean of the data\n", + " mean = getMean(data)\n", + " # Find the standard deviation of the data\n", + " std_abs = getMeanAbsDeviation(data)\n", + " std_squared = getMeanSquaredDeviation(data)\n", + " # Find the minimum and maximum values of the data\n", + " min = np.min(data)\n", + " max = np.max(data)\n", + " # Find the range of the data\n", + " range = max - min\n", + " # Find the offset\n", + " offset = mean - min\n", + " # Return the calibration values\n", + " return offset, mean, std_abs, std_squared" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "b484e7466d1310c7063c3e2acaced4c395fa7098aef2c94d8ed134d16efa77f6" + }, + "kernelspec": { + "display_name": "Python 3.7.11 ('neuro')", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.7.11" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +}