2023/3/1-2023/3/4 脑机接口学习内容一览:
这一篇博客里,将对来自于kaggle的P300数据集(P300-Dataset)开展研究,并且简化了任务流程,从识别P300电位对应字符到识别是否出现P300电位,因此准确率较高,处理难度主要出现在陌生数据集的处理方面。
数据集:
本次代码的数据集具体情况如下:
channelrange 值为1-8,表示有8个通道fs 采样频率为250hztrial 试验次数为35次,data中包含每一个实验开始的样本点data.y 在数据y你就有了所有的标记,这些标记表示这个时间点是否与刺激巧合,而刺激是为了引起P300反应。如果该值为0,则没有应用刺激。如果该值为1,则应用了与预期字母不对应的行/列刺激。如果该值为2,则存在与目标字母对应的行/列。stims 刺激(具体意义未知)channelnames 通道名称:['FZ', 'CZ', 'P3', 'PZ', 'P4', 'PO7', 'PO8', 'OZ']samples 8个通道的采样点word 35次试验中每一次刺激出现的字母['TOKENMIRARJUJUYMANSOCINCOJUEGOQUESO']data.trial 35个试验中每个试验开始的样本点data.flash Flash由4个字段组成,[采样点id,持续时间,刺激,命中/nohit]。第一个是指数时间点,刺激开始的地方。下面是时间点上的持续时间,一个标记表示激活了什么刺激,如果刺激是目标字母,则命中/未命中。格式[7486, 31, 11, 1]
处理思路:
对于这个数据集,我们根据flash的最后一个字段区分注释。
注释分为两个类别1和2,1表示命中,2为未命中。
接下来将提取前一段数据的1、2段特征来对后面数据段是否命中做出判断。
以下是处理数据的步骤:
1.将flash整理为data的注释,并将其转化为events对象。
2.根据events构建epoch进行特征提取。
3.根据提取特征进行后半段数据的预测(是否为P300刺激段) 。
因此我们得到了这个提取数据的函数:
def Extractive(file):original_data = scio.loadmat(file)samples = original_data['samples'].transpose()ch_names = ['FZ', 'CZ', 'P3', 'PZ', 'P4', 'PO7', 'PO8', 'OZ']sfreq = 250info = mne.create_info(ch_names, sfreq)raw = mne.io.RawArray(samples, info)# raw.plot()# plt.pause(0)inside_data = original_data['data']# print(inside_data.dtype)# np.set_printoptions(threshold=np.inf)# print(inside_data['flash'][0][0])flash = inside_data['flash'][0][0]# 创建注释onset = flash[:, 0]/250description = flash[:, 3]annot = mne.Annotations(onset=onset, duration=0.125, description=description)# print(annot)raw.set_annotations(annot)# print(samples.shape)raw_train = raw.copy()raw_test = raw.copy()raw_train.crop(tmin=200, tmax=500)raw_test.crop(tmin=500, tmax=800)'''raw_train.plot()raw_test.plot()plt.pause(0)'''return raw_train, raw_test
结果分析:
在经过最麻烦的提取数据工作之后,我们得到了mne支持的raw数据。这给我们的研究带来了很大的便利,因为根据之前的博客代码实践:对脑电信号进行特征提取并分类(二分类)中的函数处理方式,我们可以直接将提取过后的数据套入相同的流程中进行处理(实际上我也这样做了),虽然用了这种看起来有些偷懒的方法,但最后得到的准确率与我想象的有较大差别——只有83%左右,按照我的思路来说,P300这么明显的电位识别率不应该在80%上下。
如下是输出的准确率,分别是SVC,随机森林以及决策树方法:
Accuracy score: 0.8309178743961353
Accuracy score: 0.8309178743961353
Accuracy score: 0.7246376811594203
经过思考,我列出了以下可能造成误差的原因:
1.数据的提取存在不足,有些提供的信息并没有被我用上。
2.在raw.plot()之后,我发现flash的数据标注存在一定的问题——部分1和2的标注是重合的。
3.对P300的了解不足,没有选取合适的数据段。
4.缺少预处理步骤。
准确率低的原因一部分是由于我的处理存在不足,同时也与我对数据集的不了解,数据集制作者没有提供足够详细的信息有关。希望下一次我可以有所进步,尝试一下功率谱之外的提取特征的方法。
完整代码:
import scipy.io as scio
import numpy as np
import matplotlib.pyplot as plt
import mne
from mne_icalabel import label_components
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer
from mne.preprocessing import (ICA, create_eog_epochs, create_ecg_epochs, corrmap)
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifierdef Extractive(file):original_data = scio.loadmat(file)# print(original_data.keys())'''for i in original_data.keys():print('the', i, 'is', original_data.get(i))''''''在此数据集中有几个需要关注的关键字:channelrange 值为1-8,表示有8个通道fs 采样频率为250hztrial 试验次数为35次,data中包含每一个实验开始的样本点data.y 在数据y你就有了所有的标记,这些标记表示这个时间点是否与刺激巧合,而刺激是为了引起P300反应。如果该值为0,则没有应用刺激。如果该值为1,则应用了与预期字母不对应的行/列刺激。如果该值为2,则存在与目标字母对应的行/列。stims 刺激(具体意义未知)channelnames 通道名称:['FZ', 'CZ', 'P3', 'PZ', 'P4', 'PO7', 'PO8', 'OZ']samples 8个通道的采样点word 35次试验中每一次刺激出现的字母['TOKENMIRARJUJUYMANSOCINCOJUEGOQUESO']data.trial 35个试验中每个试验开始的样本点data.flash Flash由4个字段组成,[采样点id,持续时间,刺激,命中/nohit]。第一个是指数时间点,刺激开始的地方。下面是时间点上的持续时间,一个标记表示激活了什么刺激,如果刺激是目标字母,则命中/未命中。格式[7486, 31, 11, 1]'''# print(original_data['stims'].shape)samples = original_data['samples'].transpose()ch_names = ['FZ', 'CZ', 'P3', 'PZ', 'P4', 'PO7', 'PO8', 'OZ']sfreq = 250info = mne.create_info(ch_names, sfreq)raw = mne.io.RawArray(samples, info)# raw.plot()# plt.pause(0)inside_data = original_data['data']# print(inside_data.dtype)# np.set_printoptions(threshold=np.inf)# print(inside_data['flash'][0][0])flash = inside_data['flash'][0][0]"""根据flash的最后一个字段区分注释注释分为两个类别1和2,1表示命中,2为未命中接下来将提取前一段数据的1、2段特征来对后面数据段是否命中做出判断以下是处理数据的步骤:1.将flash整理为data的注释,并将其转化为events对象2.根据events构建epoch进行特征提取3.根据提取特征进行后半段数据的预测(是否为P300刺激段)"""# 创建注释onset = flash[:, 0]/250description = flash[:, 3]annot = mne.Annotations(onset=onset, duration=0.125, description=description)# print(annot)raw.set_annotations(annot)# print(samples.shape)raw_train = raw.copy()raw_test = raw.copy()raw_train.crop(tmin=200, tmax=500)raw_test.crop(tmin=500, tmax=800)'''raw_train.plot()raw_test.plot()plt.pause(0)'''return raw_train, raw_testdef transform(raw):event_id = {'1': 1, '2': 2, '1/2': 2}# 重设通道名称channel_types = {'FZ': 'eeg', 'CZ': 'eeg', 'P3': 'eeg', 'PZ': 'eeg','P4': 'eeg', 'PO7': 'eeg', 'PO8': 'eeg', 'OZ': 'eeg'}raw.set_channel_types(channel_types)events, _ = mne.events_from_annotations(raw, event_id=event_id) # 将注释转化为events# mne.viz.plot_events(events_train, event_id=event_id) # 绘制事件发生时间分布图# plot一下估计每一个事件持续的长度# raw.plot()# plt.pause(0)epochs = mne.Epochs(raw=raw, events=events, tmin=-0.125, tmax=0.875, event_id=event_id,preload=True, event_repeated='drop')return epochs, eventsdef eeg_power_band(epochs):"""该函数根据epochs的特定频段中的相对功率来创建eeg特征使用welch方法可得到83%左右正确率,使用multi方法只得到70%左右,效果不是很好"""# 特定频带FREQ_BANDS = {"delta": [0.5, 4.5],"theta": [4.5, 8.5],# "alpha": [8.5, 11.5],"sigma": [11.5, 15.5],"beta": [15.5, 30]}spectrum = epochs.compute_psd(method='welch', picks='eeg', fmin=0.5, fmax=30., n_fft=64, n_overlap=10)psds, freqs = spectrum.get_data(return_freqs=True)# 归一化 PSDspsds /= np.sum(psds, axis=-1, keepdims=True)X = []for fmin, fmax in FREQ_BANDS.values():psds_band = psds[:, :, (freqs >= fmin) & (freqs < fmax)].mean(axis=-1)X.append(psds_band.reshape(len(psds), -1))# print(X)return np.concatenate(X, axis=1)def classification(epochs_train, epochs_test, event_id, k):if k == 0:pipe = pipe = make_pipeline(FunctionTransformer(eeg_power_band, validate=False),SVC(C=1.2, kernel='linear', random_state=42))elif k == 1:pipe = pipe = make_pipeline(FunctionTransformer(eeg_power_band, validate=False),RandomForestClassifier(n_estimators=100, random_state=42))elif k == 2:pipe = pipe = make_pipeline(FunctionTransformer(eeg_power_band, validate=False),DecisionTreeClassifier(random_state=42))# 训练y_train = epochs_train.events[:, 2]pipe.fit(epochs_train, y_train)# 预测y_pred = pipe.predict(epochs_test)# 评估准确率y_test = epochs_test.events[:, 2]acc = accuracy_score(y_test, y_pred)print("Accuracy score: {}".format(acc))def main():file = 'C:/Users/86136/Desktop/innovation/python_mne/data/archive/P300S01.mat'# 从数据集中提取数据raw_train, raw_test = Extractive(file)# 生成epochs和events对象epochs_train, events_train = transform(raw_train)epochs_test, events_test = transform(raw_test)'''特征工程:将两个事件的epoch综合展示'''event_id = {'1': 1, '2': 2}fig, (ax1, ax2) = plt.subplots(ncols=2)stages = sorted(event_id.keys())for ax, title, epochs in zip([ax1, ax2], ['train', 'test'], [epochs_train, epochs_test]):for stage, color in zip(stages, ['red', 'blue']):epochs[stage].plot_psd(area_mode=None, color=color, ax=ax,fmin=0.1, fmax=20., show=False,average=True, spatial_colors=False)ax.set(title=title, xlabel='Frequency (Hz)')ax2.set(ylabel='uV^2/hz (dB)')ax2.legend(ax2.lines[2::3], stages)plt.tight_layout()# plt.show()'''分类预测'''for k in range(3):classification(epochs_train, epochs_test, event_id, k)main()