信号处理--使用EEGNet进行BCI脑电信号的分类

news/2024/11/8 16:48:16/

目录

理论

工具

方法实现

代码获取


理论

EEGNet作为一个比较成熟的框架,在BCI众多任务中,表现出不俗的性能。EEGNet 的主要特点包括:1)框架相对比较简单紧凑 2)适合许多的BCI脑电分析任务 3)使用两种卷积 Depth-wise convolution 和 separable convolution 实现普适特征的提取。

工具

Pytorch

P300 visual-evoked potentials数据集

error-related negativity responses (ERN) 数据集

movement-related cortical potentials (MRCP) 数据集

sensory motor rhythms (SMR) 数据集

方法实现

EEGNet模型定义

class EEGNet(nn.Module):def __init__(self):super(EEGNet, self).__init__()self.T = 120# Layer 1self.conv1 = nn.Conv2d(1, 16, (1, 64), padding = 0)self.batchnorm1 = nn.BatchNorm2d(16, False)# Layer 2self.padding1 = nn.ZeroPad2d((16, 17, 0, 1))self.conv2 = nn.Conv2d(1, 4, (2, 32))self.batchnorm2 = nn.BatchNorm2d(4, False)self.pooling2 = nn.MaxPool2d(2, 4)# Layer 3self.padding2 = nn.ZeroPad2d((2, 1, 4, 3))self.conv3 = nn.Conv2d(4, 4, (8, 4))self.batchnorm3 = nn.BatchNorm2d(4, False)self.pooling3 = nn.MaxPool2d((2, 4))# FC Layer# NOTE: This dimension will depend on the number of timestamps per sample in your data.# I have 120 timepoints. self.fc1 = nn.Linear(4*2*7, 1)def forward(self, x):# Layer 1x = F.elu(self.conv1(x))x = self.batchnorm1(x)x = F.dropout(x, 0.25)x = x.permute(0, 3, 1, 2)# Layer 2x = self.padding1(x)x = F.elu(self.conv2(x))x = self.batchnorm2(x)x = F.dropout(x, 0.25)x = self.pooling2(x)# Layer 3x = self.padding2(x)x = F.elu(self.conv3(x))x = self.batchnorm3(x)x = F.dropout(x, 0.25)x = self.pooling3(x)# FC Layerx = x.view(-1, 4*2*7)x = F.sigmoid(self.fc1(x))return xnet = EEGNet().cuda(0)
print net.forward(Variable(torch.Tensor(np.random.rand(1, 1, 120, 64)).cuda(0)))
criterion = nn.BCELoss()
optimizer = optim.Adam(net.parameters())

 

 评估模型分类的相关指标

def evaluate(model, X, Y, params = ["acc"]):results = []batch_size = 100predicted = []for i in range(len(X)/batch_size):s = i*batch_sizee = i*batch_size+batch_sizeinputs = Variable(torch.from_numpy(X[s:e]).cuda(0))pred = model(inputs)predicted.append(pred.data.cpu().numpy())inputs = Variable(torch.from_numpy(X).cuda(0))predicted = model(inputs)predicted = predicted.data.cpu().numpy()for param in params:if param == 'acc':results.append(accuracy_score(Y, np.round(predicted)))if param == "auc":results.append(roc_auc_score(Y, predicted))if param == "recall":results.append(recall_score(Y, np.round(predicted)))if param == "precision":results.append(precision_score(Y, np.round(predicted)))if param == "fmeasure":precision = precision_score(Y, np.round(predicted))recall = recall_score(Y, np.round(predicted))results.append(2*precision*recall/ (precision+recall))return results

 模型的训练和测试

batch_size = 32for epoch in range(10):  # loop over the dataset multiple timesprint "\nEpoch ", epochrunning_loss = 0.0for i in range(len(X_train)/batch_size-1):s = i*batch_sizee = i*batch_size+batch_sizeinputs = torch.from_numpy(X_train[s:e])labels = torch.FloatTensor(np.array([y_train[s:e]]).T*1.0)# wrap them in Variableinputs, labels = Variable(inputs.cuda(0)), Variable(labels.cuda(0))# zero the parameter gradientsoptimizer.zero_grad()# forward + backward + optimizeoutputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.data[0]# Validation accuracyparams = ["acc", "auc", "fmeasure"]print paramsprint "Training Loss ", running_lossprint "Train - ", evaluate(net, X_train, y_train, params)print "Validation - ", evaluate(net, X_val, y_val, params)print "Test - ", evaluate(net, X_test, y_test, params)

模型提取部分特征的可视化

 

代码获取

信号处理-使用EEGNet进行BCI脑电信号的分类icon-default.png?t=N7T8https://download.csdn.net/download/YINTENAXIONGNAIER/89025247


http://www.ppmy.cn/news/1394642.html

相关文章

深入探索C语言动态内存分配:释放你的程序潜力

🌈大家好!我是Kevin,蠢蠢大一幼崽,很高兴你们可以来阅读我的博客! 🌟我热衷于分享🖊学习经验,🏫多彩生活,精彩足球赛事⚽ 🌟感谢大家的支持&#…

【boost_search搜索引擎】1.获取数据源

boost搜索引擎 1、项目介绍2、获取数据源 1、项目介绍 boost_search项目和百度那种不一样,百度是全站搜索,而boost_search是一个站内搜索。而项目的宏观上实现思路就如同图上的思路。 2、获取数据源 我们要实现一个站内搜索,我们就要有这…

浅易理解:YoloV3 案例_05

目标 熟悉TFRecord文件的使用方法YoloV3模型结构及构建方法数据处理方法利用yoloV3模型进行训练和预测 1.TFrecord文件 该案例中我们依然使用VOC数据集来进行目标检测,不同的是我们要利用tfrecord文件来存储和读取数据,首先来看一下tfrecord文件的相关…

[AIGC] Redis基础命令集详细介绍

Redis是一个强大的开源的键-值存储系统,被广泛应用于各种应用程序中。在使用Redis时,我们需要掌握一些基本的Redis命令来操作存储在其上的数据。这篇文章将向你介绍一些基本的Redis命令,让你能够更好地使用和理解Redis。 文章目录 启动Redis…

yolov5/v7修改标签和检测框显示【最全】

1. 背景介绍 在计算机视觉领域,目标检测是一个重要的任务,它旨在识别图像中的对象并定位它们的边界框。近年来,基于深度学习的目标检测算法取得了显著的进展,其中YOLO(You Only Look Once)系列算法因其速度…

nodejs+vue分类信息服务平台移动端的设计与实现-安卓pythonflask-django-php

分类信息服务平台设计的目的是为用户提供活动信息、活动记录等方面的平台。 与PC端应用程序相比,分类信息服务平台的设计主要面向于移动端,旨在为管理员和用户、商铺提供一个分类信息服务平台。用户可以通过Android及时查看活动信息等。 分类信息服务平台…

pytorch 实现多层神经网络MLP(Pytorch 05)

一 多层感知机 最简单的深度网络称为多层感知机。多层感知机由 多层神经元 组成,每一层与它的上一层相连,从中接收输入;同时每一层也与它的下一层相连,影响当前层的神经元。 softmax 实现了 如何处理数据,如何将 输出…

【Godot 3.5控件】用TextureProgress制作血条

说明 本文写自2022年11月13日-14日,内容基于Godot3.5。后续可能会进行向4.2版本的转化。 概述 之前基于ProgressBar创建过血条组件。它主要是基于修改StyleBoxFlat,好处是它几乎可以算是矢量的,体积小,所有东西都是样式信息&am…