# Repository path: bci_algo/SSVEP/dwfbcca.py
# Snapshot exported from the repository web viewer on 2026-06-05 09:34:29 +08:00 (530 lines, 20 KiB).
# The viewer flagged "ambiguous Unicode characters"; these presumably come from the
# Chinese text in comments and string literals and are intentional.
# -*- coding: utf-8 -*-
import os
import time
import warnings
from os import error
import numpy as np
import scipy
from numpy.linalg import linalg
from scipy.io import loadmat
from scipy.linalg import qr
from scipy.signal import filtfilt, lfilter
# from numpy.linalg import _umath_linalg
class FbccaDw:
    """Filter-bank CCA (FBCCA) SSVEP classifier with dynamic, multi-window stopping.

    Incoming EEG chunks are notch filtered, split into ``num_filter`` sub-bands and
    accumulated into ``winNum`` staggered windows; a new window starts roughly every
    ``stimTime`` seconds. Each window keeps a running QR factor of its data
    (``Rbuffer``) and a running reference/data cross product (``Cxy``), so canonical
    correlations can be updated incrementally per chunk. A target is reported as soon
    as the best window's cost value drops below a caller-supplied threshold.
    """

    def __init__(self, fs, num_target, num_chans, num_filter, num_harms, stimTime, parameter, width,
                 winNum, method=1):
        """Store the configuration and allocate per-window streaming state.

        Args:
            fs: sampling rate in Hz.
            num_target: number of stimulation targets.
            num_chans: number of EEG channels.
            num_filter: number of filter-bank sub-bands.
            num_harms: number of harmonics in the sine/cosine references.
            stimTime: delay in seconds between the starts of consecutive windows; a
                window holds at most ``stimTime * winNum`` seconds of data.
            parameter: ``(a, b)``; sub-band k (1-based) gets weight ``k**-a + b``,
                normalised so the weights sum to 1.
            width: ``(step, start)`` in Hz used by :meth:`filterFrequenceBank`.
            winNum: number of staggered windows.
            method: 1 selects the 'DW11' cost in :meth:`costF`, any other value selects
                'DW1'. Defaults to 1.
        """
        print('******************************************')
        print('parameter list')
        print('target:', num_target)
        print('number of filter bank:', num_filter)
        print('parameter:', parameter)
        print('width:', width)
        self.fs = fs
        self.num_target = num_target
        self.num_chans = num_chans
        self.num_fbs = num_filter
        self.num_harms = num_harms
        self.bandWidth = width
        self.winNum = winNum
        self.winTimeDelay = stimTime
        # A window that grows beyond this many samples is restarted.
        self.winMaxSampleNum = self.winTimeDelay * self.winNum * self.fs
        # Samples between the starts of consecutive windows.
        self.winDelayNum = round(self.winTimeDelay * self.fs)
        # Float base so an integer exponent in `parameter` is also accepted.
        parameterValue = np.power(np.arange(1, self.num_fbs + 1, dtype=float), -(parameter[0])) + parameter[1]
        self.weightValue = parameterValue / (sum(parameterValue))
        self.rhoNum = 2  # number of canonical correlations myCCA returns (only the first is used)
        # Notch-filter coefficients ("north" is a historical spelling of "notch").
        self.north_b = []
        self.north_a = []
        self.filterBank_A = []
        self.filterBank_B = []
        self.winStep = 1  # stride over windows; doubled by fbccaDWMW when a call is too slow
        self.DW_cost_method = 'DW11' if method == 1 else 'DW1'
        # Per-window accumulators and filter states start empty.
        self.onlineInit()
        self.filterInit()

    def filterFrequenceBank(self):
        """Design the Chebyshev type-I band-pass filter bank (one filter per sub-band).

        Sub-band k has passband [max(f_k - 2, 1), 92] Hz and stopband edges
        [max(passband_low - 4, 1), 102] Hz, where f_k runs over
        ``range(width[1], 90, width[0])``. The order comes from ``cheb1ord`` with
        3 dB maximum passband loss and 40 dB minimum stopband attenuation, and the
        filter is built with 0.5 dB passband ripple.

        Replaces any previously designed filters in ``filterBank_A``/``filterBank_B``.

        Raises:
            ValueError: if ``width`` yields fewer than ``num_fbs`` sub-bands.
        """
        lastFrequence = 90  # sub-band start frequencies stay below this
        fStep = self.bandWidth[0]
        fStart = self.bandWidth[1]
        passHigh = lastFrequence + 2
        stopHigh = passHigh + 10
        Nq = self.fs / 2
        passLows = [max(f - 2, 1) for f in range(fStart, lastFrequence, fStep)]
        if len(passLows) < self.num_fbs:
            raise ValueError('width=%r yields only %d sub-bands, %d required'
                             % (self.bandWidth, len(passLows), self.num_fbs))
        self.filterBank_A = []
        self.filterBank_B = []
        for passLow in passLows[:self.num_fbs]:
            stopLow = max(passLow - 4, 1)
            Wp = [passLow / Nq, passHigh / Nq]
            Ws = [stopLow / Nq, stopHigh / Nq]
            N, Wn = scipy.signal.cheb1ord(Wp, Ws, 3, 40)
            B, A = scipy.signal.cheby1(N, 0.5, Wn, 'bandpass')
            self.filterBank_A.append(A)
            self.filterBank_B.append(B)

    def filterbank(self, eeg):
        """Split ``eeg`` into sub-bands with the designed filter bank.

        Filter states persist in ``self.filterZf``, so consecutive calls behave like
        one continuous filtering pass. An all-zero state means "start fresh".

        Args:
            eeg: 2-D array (num_chans, n_samples).

        Returns:
            ndarray (num_fbs, n_samples, num_chans). The sample and channel axes are
            swapped relative to the input.
        """
        eeg = np.asarray(eeg)
        filterData = np.zeros((self.num_fbs, eeg.shape[1], eeg.shape[0]))
        for idx in range(self.num_fbs):
            b = self.filterBank_B[idx]
            a = self.filterBank_A[idx]
            zi = self.filterZf[idx]
            if np.all(zi == 0):
                # Zero initial conditions; one lfilter call yields both output and final state.
                zi = np.zeros((self.num_chans, max(len(a), len(b)) - 1))
            filtered, self.filterZf[idx] = lfilter(b, a, eeg, zi=zi)
            filterData[idx] = filtered.T
        return filterData

    def process(self, data):
        """Mean-centre each row of ``data`` and return an orthonormal basis for the rows.

        Args:
            data: 2-D array-like (n_signals, n_samples); ``np.matrix`` is accepted.

        Returns:
            (Q, rankQ): Q is (n_samples, rankQ) with orthonormal columns; rankQ is the
            numerical rank of the triangular QR factor.

        Raises:
            ValueError: if the centred data has rank 0.
        """
        signals = np.asarray(data, dtype=float)
        centred = signals - signals.mean(axis=1, keepdims=True)
        Q, R = qr(centred.T, mode='economic')
        rankQ = np.linalg.matrix_rank(R)
        if rankQ == 0:
            raise ValueError('stats:canoncorr:badData')
        # Keep the leading rankQ columns. Without column pivoting this is only exact
        # when any dependent signals come last.
        return Q[:, :rankQ], rankQ

    def reference(self, listFreqs, numberSmples, num_harms):
        """Build orthonormalised sine/cosine references for every stimulus frequency.

        Args:
            listFreqs: stimulus frequencies in Hz.
            numberSmples: reference length in samples.
            num_harms: number of harmonics (sin and cos for each).

        Returns:
            ndarray (len(listFreqs), 2 * num_harms, numberSmples) whose rows are
            orthonormal for each frequency. A rank-deficient reference set makes the
            row assignment fail with a shape error.
        """
        timeIndex = np.arange(1, numberSmples + 1) / self.fs
        referenceData = np.zeros((len(listFreqs), 2 * num_harms, numberSmples))
        for idx, stimFrequence in enumerate(listFreqs):
            rows = []
            for harm in range(1, num_harms + 1):
                angle = 2 * np.pi * timeIndex * harm * stimFrequence
                rows.extend([np.sin(angle), np.cos(angle)])
            Q, _ = self.process(np.array(rows))
            referenceData[idx] = Q.T
        return referenceData

    def setNotchFilterPara(self):
        """Load fixed notch-filter coefficients into ``north_b`` / ``north_a``.

        Per the original design notes, the coefficients came from
        ``scipy.signal.iirnotch`` at 50 Hz with Q=35, cascaded three times by
        repeating its zeros and poles.
        NOTE(review): they only notch 50 Hz at one sampling rate. b[1]/b[0] equals
        -6*cos(2*pi*50/fs) for fs = 250 Hz, so confirm before using another fs.
        """
        self.north_b = [0.94801603944125245604368501517456, -1.7577184027642647201616910024313,
                        3.9303778338832491279219993884908, -3.7392330345967859095424046245171,
                        3.9303778338832482397435796883656, -1.7577184027642638319832713023061,
                        0.94801603944125156786526531504933]
        self.north_a = [1, -1.8214007435820627200939725298667, 4.0000101767406484043476666556671,
                        -3.7380998614928691026193519064691, 3.8589119784285759173769747576443,
                        -1.6951692350503837491970671180752, 0.89786559147978006745205448169145]

    def northFilter(self, data):
        """Apply the notch filter to ``data``, keeping state across calls in ``notchZh``.

        Args:
            data: 2-D array (num_chans, n_samples).

        Returns:
            np.matrix of the filtered data, same shape as ``data``.

        Errors from ``lfilter`` propagate to the caller; they are not swallowed.
        """
        zi = self.notchZh[0]
        if np.all(zi == 0):
            zi = np.zeros((self.num_chans, max(len(self.north_a), len(self.north_b)) - 1))
        dataFiltered, self.notchZh[0] = lfilter(self.north_b, self.north_a, data, zi=zi)
        return np.asmatrix(dataFiltered)

    def getDataQ(self, data, Rbuffer):
        """Update each sub-band's running QR factor with a new chunk.

        With [R_old; X_new] = Q R_new, the top block of Q (Qs1) maps the previous
        orthonormal basis onto the new one, and the bottom block (Qs2) carries the
        new samples. weightCCA uses this to update Cxy without revisiting old data.

        Args:
            data: ndarray (num_fbs, n_samples, num_chans), as returned by filterbank.
            Rbuffer: ndarray (num_fbs, num_chans, num_chans); updated in place.

        Returns:
            (Qs1, Qs2, Rbuffer): per-sub-band lists. Qs1[k] is (num_chans, num_chans)
            and all zeros when the window was empty. Qs2[k] has one row per new sample.
        """
        Qs1 = [0] * self.num_fbs
        Qs2 = [0] * self.num_fbs
        for fb in range(self.num_fbs):
            chunk = data[fb]  # no squeeze: single-sample / single-channel chunks stay 2-D
            if np.all(Rbuffer[fb] == 0):
                Q, R = qr(chunk, mode='economic')
                Qs1[fb] = np.zeros([self.num_chans, self.num_chans])
                Qs2[fb] = Q
            else:
                Q, R = qr(np.concatenate((Rbuffer[fb], chunk), axis=0), mode='economic')
                Qs1[fb] = Q[:self.num_chans, :]
                Qs2[fb] = Q[self.num_chans:, :]
            Rbuffer[fb] = R
        return Qs1, Qs2, Rbuffer

    def myCCA(self, dataQ, Qc2y, d):
        """Return the ``d`` largest singular values of ``Qc2y @ dataQ``.

        With orthonormal bases on both sides, these are the canonical correlations.
        An empty ``Qc2y`` means ``dataQ`` already is that cross product.
        """
        Cov = dataQ if len(Qc2y) == 0 else np.dot(Qc2y, dataQ)
        return np.linalg.svd(Cov, compute_uv=False)[0:d]

    def weightCCA(self, Qs1, Qs2, ref, Cxy):
        """Update the cross products and score every target.

        Args:
            Qs1, Qs2: outputs of getDataQ for the current chunk.
            ref: ndarray (num_target, 2 * num_harms, n_samples), the reference slice
                aligned with the chunk.
            Cxy: ndarray (num_fbs, num_target, 2 * num_harms, num_chans); updated in place.

        Returns:
            (result, rho, Cxy). ``rho[t]`` is the weighted sum over sub-bands of the
            squared largest canonical correlation for target t, and ``result`` is
            ``argmax(rho)``.
        """
        rMax = np.zeros([self.num_fbs, self.num_target])
        for fi in range(self.num_fbs):
            for si in range(self.num_target):
                Qc2y = ref[si]  # no squeeze: a 1-sample slice must stay (2*num_harms, 1)
                if np.all(Cxy[fi, si] == 0):
                    Cxy[fi, si] = np.dot(Qc2y, Qs2[fi])
                else:
                    # Rotate the old product into the new basis, then add the new samples.
                    Cxy[fi, si] = np.dot(Cxy[fi, si], Qs1[fi]) + np.dot(Qc2y, Qs2[fi])
                rMax[fi, si] = self.myCCA(Cxy[fi, si], [], self.rhoNum)[0]
        rho = np.dot(self.weightValue, np.power(rMax, 2))
        result = np.argmax(rho)
        return result, rho, Cxy

    def costF(self, rho, method, C):
        """Turn per-target scores into a decision value (lower means more confident).

        Args:
            rho: per-target scores (array-like, at least two entries).
            method: 'DW1', 'DW11' or 'DW2'.
            C: constant used only by 'DW2'.

        Returns:
            The decision value, which fbccaDWMW compares against its threshold.

        Raises:
            ValueError: for an unknown ``method``.
        """
        ranked = sorted(np.ravel(rho).tolist(), reverse=True)
        if method == 'DW1':
            return (ranked[0] - ranked[1]) / (sum(ranked) - self.num_target * np.log(sum(np.exp(ranked))))
        if method == 'DW11':
            return -(ranked[0] - ranked[1])
        if method == 'DW2':
            return (ranked[0] - C) / (ranked[1] - ranked[0])
        raise ValueError('unknown cost method: %r' % (method,))

    def onlineInit(self):
        """Reset window lengths, per-window accumulators and the downsampling phase."""
        self.dataUseLen = [0] * self.winNum
        self.Rbuffer = np.zeros([self.winNum, self.num_fbs, self.num_chans, self.num_chans])
        self.Cxy = np.zeros([self.winNum, self.num_fbs, self.num_target, 2 * self.num_harms, self.num_chans])
        self.phase = 0

    def filterInit(self):
        """Reset notch and filter-bank states (zero means "start from rest")."""
        self.notchZh = [0]
        self.filterZf = [0] * self.num_fbs

    def warmFilter(self, data):
        """Run pre-stimulus data through the filters so their transients settle first.

        Only the filter states are updated; the filtered output is discarded.
        Any downsampling is expected to have happened before this call.
        """
        self.filterbank(self.preprocessFilter(data))

    def myDownSample(self, data, n):
        """Keep every ``n``-th sample of the first 8 rows, continuous across calls.

        ``self.phase`` is the offset of the next kept sample in the next chunk.

        Args:
            data: 2-D array-like (channels, n_samples).
            n: downsampling factor.

        Returns:
            ndarray (min(8, channels), kept_samples).
        """
        data = np.asarray(data)[:8]
        total = data.shape[1]
        if total <= self.phase:
            # The whole chunk lies before the next sample to keep.
            self.phase -= total
            return data[:, :0].copy()
        data = data[:, self.phase:]
        self.phase = n - 1 - (data.shape[1] - 1) % n
        return data[:, ::n].copy()

    def preprocessFilter(self, data):
        """Preprocess a raw chunk (currently only the notch filter; downsampling is disabled)."""
        return self.northFilter(data[:, :])

    def fbccaDWMW(self, testdata, referenceData, tValue, calculateCount=None):
        """Dynamic-window classification of one incoming EEG chunk.

        Args:
            testdata: raw EEG chunk (num_chans, n_samples).
            referenceData: output of :meth:`reference`. Its last axis must cover at
                least ``winMaxSampleNum`` samples.
            tValue: decision threshold. A target is reported once the smallest cost
                over the active windows is below it.
            calculateCount: unused; accepted for caller compatibility.

        Returns:
            Index of the detected target, or -1 if no decision was reached.
        """
        t1 = time.time()
        res = -1
        predLabel = -1
        minEps = float("inf")
        northData = self.preprocessFilter(testdata)
        newSampleNum = northData.shape[1]
        if newSampleNum > self.winDelayNum:
            # Chunks longer than the window delay make window starts coarser than intended.
            warnings.warn('need add window delay time')
        if newSampleNum < self.num_chans:
            # The first QR of a window needs at least num_chans samples for a full-rank factor.
            warnings.warn('data shape is [%d %d] need more data' % (newSampleNum, northData.shape[0]))
        filterData = self.filterbank(northData)
        winMinTime = 0  # minimum window length [s] before a window may vote
        for wi in range(0, self.winNum, self.winStep):
            if wi == 0:
                self.dataUseLen[wi] += newSampleNum
            elif self.dataUseLen[wi] == 0:
                # A window starts only once its predecessor holds more than one delay of data.
                if self.dataUseLen[wi - self.winStep] > self.winDelayNum * self.winStep:
                    self.dataUseLen[wi] = newSampleNum
                else:
                    break
            else:
                self.dataUseLen[wi] += newSampleNum
            if self.dataUseLen[wi] > self.winMaxSampleNum:
                # Window is full: restart it from the current chunk.
                self.dataUseLen[wi] = newSampleNum
                self.Rbuffer[wi] = 0
                self.Cxy[wi] = 0
            Qs1, Qs2, self.Rbuffer[wi] = self.getDataQ(filterData, self.Rbuffer[wi])
            ei = self.dataUseLen[wi]
            ref = referenceData[:, :, ei - newSampleNum:ei]
            predLabel_new, rho_new, self.Cxy[wi] = self.weightCCA(Qs1, Qs2, ref, self.Cxy[wi])
            if self.dataUseLen[wi] > winMinTime * self.fs:
                epsilon = self.costF(rho_new, self.DW_cost_method, C=0)
                if epsilon < minEps:
                    minEps = epsilon
                    predLabel = predLabel_new
        if minEps < tValue:
            res = predLabel
        if time.time() - t1 > 0.2 and self.winStep < 16:
            # Too slow for real time: evaluate fewer windows from now on.
            self.winStep = self.winStep * 2
        return res
if __name__ == '__main__':
    # Offline replay: for each trial, warm up the filters on pre-stimulus data, then feed
    # recorded EEG slices one at a time until the classifier commits to a target.
    fs = 250  # sampling rate [Hz]
    num_chans = 8
    num_target = 40
    num_filterBank = 3  # number of sub-bands in the filter-bank analysis
    num_harm = 5
    stimTime = 0.2  # delay between window starts [s]
    winNum = 50  # number of staggered windows
    trials = 1
    step = 50  # maximum number of data slices replayed per trial
    list_freqs = [8.0, 8.2, 8.4, 8.6, 8.8, 9.0, 9.2, 9.4, 9.6, 9.8, 10.0, 10.2, 10.4, 10.6, 10.8, 11.0, 11.2, 11.4,
                  11.6, 11.8, 12.0, 12.2, 12.4, 12.6, 12.8, 13., 13.2, 13.4, 13.6, 13.8, 14., 14.2, 14.4, 14.6, 14.8,
                  15.0, 15.2, 15.4, 15.6, 15.8]
    dw = FbccaDw(fs, num_target, num_chans, num_filterBank, num_harm, stimTime, [1.0, 0.3], [8, 8], winNum,
                 method=1)
    dw.filterFrequenceBank()
    # References must cover the longest window: winNum * stimTime seconds.
    referenceData = dw.reference(list_freqs, int(winNum * stimTime * fs), num_harm)
    dw.setNotchFilterPara()
    prelabels = np.zeros((1, trials))
    path = "D:\\工作相关\\项目代码\\SDK封装与测试\\八神BCI脑电信号SDK20230105\\双鹰SDK对比\\offline_data250\\"
    for index in range(1, trials + 1):
        D = loadmat(path + str(index) + '-warmData.mat')
        warmData = D['warmData']
        dw.onlineInit()
        dw.filterInit()
        # The filters expect (channels, samples); presumably stored as (samples, channels).
        dw.warmFilter(warmData.T)
        res = -1
        tagget_i = 0
        for tagget_i in range(1, step + 1):
            D = loadmat(path + str(index) + '-' + str(tagget_i) + '.mat')
            dataSlice = D['dataTemp']
            res = dw.fbccaDWMW(dataSlice.T, referenceData, tValue=-0.2, calculateCount=tagget_i)
            if res != -1:
                break
        # Labels are stored 1-based; 0 means no decision within `step` slices.
        prelabels[0, index - 1] = res + 1
        print(index, '--', res + 1, " 计算轮数", tagget_i)