init commit
This commit is contained in:
529
Debug_64ch_Decoder_Optimize/SSVEP/dwfbcca.py
Normal file
529
Debug_64ch_Decoder_Optimize/SSVEP/dwfbcca.py
Normal file
@@ -0,0 +1,529 @@
|
||||
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
import time
|
||||
import warnings
|
||||
from os import error
|
||||
import numpy as np
|
||||
import scipy
|
||||
from numpy.linalg import linalg
|
||||
from scipy.io import loadmat
|
||||
from scipy.linalg import qr
|
||||
from scipy.signal import filtfilt, lfilter
|
||||
# from numpy.linalg import _umath_linalg
|
||||
|
||||
|
||||
class FbccaDw:
|
||||
def __init__(self, fs, num_target, num_chans, num_filter, num_harms, stimTime, parameter, width, winNum,method):
|
||||
print('******************************************')
|
||||
print('parameter list')
|
||||
print('target:', num_target)
|
||||
print('number of filter bank:', num_filter)
|
||||
print('parameter:', parameter)
|
||||
print('width:', width)
|
||||
self.phase = 0
|
||||
self.bandWidth = width
|
||||
self.winNum = winNum
|
||||
self.num_harms = num_harms
|
||||
self.num_target = num_target
|
||||
self.num_chans = num_chans
|
||||
self.winTimeDelay = stimTime
|
||||
self.fs = fs
|
||||
self.winMaxSampleNum = self.winTimeDelay * self.winNum * self.fs
|
||||
self.winDelayNum = round(self.winTimeDelay * self.fs)
|
||||
self.num_fbs = num_filter
|
||||
parameterValue = np.power(np.arange(1, self.num_fbs + 1), -(parameter[0])) + parameter[1]
|
||||
self.weightValue = parameterValue / (sum(parameterValue))
|
||||
|
||||
self.dataUseLen = [0] * self.winNum
|
||||
self.Rbuffer = np.zeros([self.winNum, self.num_fbs, self.num_chans, self.num_chans])
|
||||
self.Cxy = np.zeros([self.winNum, self.num_fbs, self.num_target, 2 * self.num_harms, self.num_chans])
|
||||
self.rhoNum = 2
|
||||
self.notchZh = [0]
|
||||
self.filterZf = [0] * self.num_fbs
|
||||
self.north_b = []
|
||||
self.north_a = []
|
||||
self.filterBank_A = []
|
||||
self.filterBank_B = []
|
||||
self.winStep = 1
|
||||
self.DW_cost_method = 'DW11' if method==1 else 'DW1'
|
||||
|
||||
'''
|
||||
filterFrequenceBank:根据刺激频率生成的通带和阻带,用于滤波器组频带分解
|
||||
'''
|
||||
|
||||
def filterFrequenceBank(self):
|
||||
# 阻带的最高频率
|
||||
lastFrequence = 90
|
||||
freqBandWidth = self.bandWidth[1]
|
||||
fStep = self.bandWidth[0]
|
||||
bandFrequence = np.zeros((5, 4))
|
||||
# 第二列频率带
|
||||
band = list(range(freqBandWidth, lastFrequence, fStep))
|
||||
band[:] = [x - 2 for x in band]
|
||||
colValue = np.maximum(np.asmatrix(band), 1)
|
||||
bandFrequence[:, 1] = colValue[0, 0:5]
|
||||
# 第一列频率带
|
||||
bandFrequence[:, 0] = np.maximum(bandFrequence[:, 1] - 4, 1)
|
||||
# 第三列频率带
|
||||
bandFrequence[:, 2] = lastFrequence + 2
|
||||
# 第四列频率带
|
||||
bandFrequence[:, 3] = bandFrequence[:, 2] + 10
|
||||
# bandFrequence = np.array([[30,33,77,82],
|
||||
# [62,68,77,82]])
|
||||
for idx_fb in range(self.num_fbs):
|
||||
Nq = self.fs / 2
|
||||
Wp = [bandFrequence[idx_fb, 1] / Nq, bandFrequence[idx_fb, 2] / Nq]
|
||||
Ws = [bandFrequence[idx_fb, 0] / Nq, bandFrequence[idx_fb, 3] / Nq]
|
||||
[N, Wn] = scipy.signal.cheb1ord(Wp, Ws, 3,
|
||||
40) # band pass filter StopBand=[Ws(1)~Ws(2)] PassBand=[Wp(1)~Wp(2)]
|
||||
[B, A] = scipy.signal.cheby1(N, 0.5, Wn, 'bandpass') # Wn passband edge frequency
|
||||
self.filterBank_A.append(A)
|
||||
self.filterBank_B.append(B)
|
||||
# def filterFrequenceBank(self):
|
||||
# # 阻带的最高频率
|
||||
# lastFrequence = 90
|
||||
# freqBandWidth = self.bandWidth[1]
|
||||
# fStep = self.bandWidth[0]
|
||||
# bandFrequence = np.zeros((5, 4))
|
||||
# # 第二列频率带
|
||||
# band = list(range(freqBandWidth, lastFrequence, fStep))
|
||||
# band[:] = [x - 2 for x in band]
|
||||
# colValue = np.maximum(np.asmatrix(band), 1)
|
||||
# bandFrequence[:, 1] = colValue[0, 0:5]
|
||||
# # 第一列频率带
|
||||
# bandFrequence[:, 0] = np.maximum(bandFrequence[:, 1] - 4, 1)
|
||||
# # 第三列频率带
|
||||
# bandFrequence[:, 2] = lastFrequence + 2
|
||||
# # 第四列频率带
|
||||
# bandFrequence[:, 3] = bandFrequence[:, 2] + 10
|
||||
# for idx_fb in range(self.num_fbs):
|
||||
# Nq = self.fs / 2
|
||||
# Wp = [bandFrequence[idx_fb, 1] / Nq, bandFrequence[idx_fb, 2] / Nq]
|
||||
# Ws = [bandFrequence[idx_fb, 0] / Nq, bandFrequence[idx_fb, 3] / Nq]
|
||||
# [N, Wn] = scipy.signal.cheb1ord(Wp, Ws, 3,
|
||||
# 40) # band pass filter StopBand=[Ws(1)~Ws(2)] PassBand=[Wp(1)~Wp(2)]
|
||||
# [B, A] = scipy.signal.cheby1(N, 0.5, Wn, 'bandpass') # Wn passband edge frequency
|
||||
# self.filterBank_A.append(A)
|
||||
# self.filterBank_B.append(B)
|
||||
|
||||
'''
|
||||
Filter bank analysis
|
||||
Input:
|
||||
eeg : Input eeg data (# of targets, # of channels, Data length [sample])
|
||||
Output:
|
||||
filterData : Generated filter Data
|
||||
'''
|
||||
|
||||
def filterbank(self, eeg):
|
||||
filterData = np.zeros((self.num_fbs, eeg.shape[1], eeg.shape[0]))
|
||||
for filterIndex in range(self.num_fbs):
|
||||
if np.all(self.filterZf[filterIndex] == 0):
|
||||
zi = np.zeros(
|
||||
[max(len(self.filterBank_A[filterIndex]), len(self.filterBank_B[filterIndex])) - 1, self.num_chans])
|
||||
_, self.filterZf[filterIndex] = lfilter(self.filterBank_B[filterIndex], self.filterBank_A[filterIndex],
|
||||
eeg, zi=zi.T)
|
||||
Data = lfilter(self.filterBank_B[filterIndex], self.filterBank_A[filterIndex], eeg)
|
||||
else:
|
||||
Data, self.filterZf[filterIndex] = lfilter(self.filterBank_B[filterIndex],
|
||||
self.filterBank_A[filterIndex], eeg,
|
||||
zi=self.filterZf[filterIndex])
|
||||
filterData[filterIndex, :, :] = Data.T
|
||||
return filterData
|
||||
|
||||
'''
|
||||
process
|
||||
矩阵的白化和QR正则化分解,降低矩阵的维度,加速计算时间
|
||||
Input:
|
||||
data : 输入的二维脑电信号
|
||||
Output:
|
||||
Q : 降维后的矩阵
|
||||
rankQ :正则矩阵的秩
|
||||
'''
|
||||
|
||||
def process(self, data):
|
||||
# 白化操作
|
||||
meanValue = np.asmatrix(data.mean(axis=1))
|
||||
meanData = np.repeat(meanValue, data.shape[1], axis=1)
|
||||
whiteTemp = data - meanData
|
||||
# QR 分解
|
||||
rankWhiteTemp = whiteTemp.shape[0]
|
||||
whiteTemp = np.transpose(whiteTemp)
|
||||
Q, R = qr(whiteTemp.A, mode='economic')
|
||||
# 计算矩阵的秩
|
||||
rankQ = linalg.matrix_rank(R)
|
||||
if rankQ == 0:
|
||||
raise ValueError('stats:canoncorr:badData')
|
||||
elif rankQ <= rankWhiteTemp:
|
||||
# warnings.warn('stats:canoncorr:NotFullRank')
|
||||
Q = Q[:, 0:rankQ]
|
||||
return Q, rankQ
|
||||
|
||||
'''
|
||||
reference
|
||||
Input:
|
||||
listFreqs : 刺激频率列表
|
||||
numberSmples : 用于分类的脑电信号采样点个数
|
||||
num_harms : 谐波数
|
||||
Output:
|
||||
y_ref : 生成的参考信号 (刺激目标数, 2 * 谐波数, 数据长度/采样点数)
|
||||
'''
|
||||
|
||||
def reference(self, listFreqs, numberSmples, num_harms):
|
||||
numberFrequence = len(listFreqs)
|
||||
timeIndex = np.arange(1, numberSmples + 1) / self.fs # time index
|
||||
referenceData = np.zeros((numberFrequence, 2 * num_harms, numberSmples))
|
||||
for frequenceIndex in range(numberFrequence):
|
||||
temp = []
|
||||
for harmIndex in range(1, num_harms + 1):
|
||||
stimFrequence = listFreqs[frequenceIndex] # in HZ
|
||||
# Sin and Cos
|
||||
temp.extend([np.sin(2 * np.pi * timeIndex * harmIndex * stimFrequence),
|
||||
np.cos(2 * np.pi * timeIndex * harmIndex * stimFrequence)])
|
||||
referenceTemp = np.asmatrix(temp)
|
||||
# 白化操作和QR分解
|
||||
Q, rankQ = self.process(referenceTemp)
|
||||
referenceData[frequenceIndex] = np.transpose(Q)
|
||||
return referenceData
|
||||
|
||||
'''
|
||||
setNorthFilterPara
|
||||
陷波器的参数初始化
|
||||
self.north_b, self.north_a : 陷波器的参数设计
|
||||
'''
|
||||
|
||||
def setNotchFilterPara(self):
|
||||
# notchFilterNum = 3
|
||||
# northFreq = 50
|
||||
# bwDen = 35
|
||||
# wo = northFreq / (self.fs / 2)
|
||||
# bw = wo / bwDen
|
||||
# self.north_b, self.north_a = iirnotch(wo, Q=35) # self.north_b, self.north_a = iircomb(northFreq, bwDen, 'notch')
|
||||
# # n倍零极点,相当于重复滤波n次
|
||||
# if notchFilterNum > 1:
|
||||
# z, p, k = tf2zpk(self.north_b, self.north_a)
|
||||
# zNew = np.repeat(z, notchFilterNum, axis=0)
|
||||
# zNew[1], zNew[4] = zNew[4], zNew[1]
|
||||
# pNew = np.repeat(p, notchFilterNum, axis=0)
|
||||
# pNew[1], pNew[4] = pNew[4], pNew[1]
|
||||
# kNew = np.power(k, notchFilterNum)
|
||||
# self.north_b, self.north_a = zpk2tf(zNew, pNew, kNew)
|
||||
self.north_b = [0.94801603944125245604368501517456, -1.7577184027642647201616910024313,
|
||||
3.9303778338832491279219993884908, -3.7392330345967859095424046245171,
|
||||
3.9303778338832482397435796883656, -1.7577184027642638319832713023061,
|
||||
0.94801603944125156786526531504933]
|
||||
|
||||
self.north_a = [1, -1.8214007435820627200939725298667, 4.0000101767406484043476666556671,
|
||||
-3.7380998614928691026193519064691, 3.8589119784285759173769747576443,
|
||||
-1.6951692350503837491970671180752, 0.89786559147978006745205448169145]
|
||||
|
||||
'''
|
||||
northFilter
|
||||
进行信号的50hz陷波处理
|
||||
Input:
|
||||
data :输入脑电数据
|
||||
Output:
|
||||
dataFiltered : 陷波处理后的脑电数据
|
||||
'''
|
||||
|
||||
def northFilter(self, data):
|
||||
try:
|
||||
if np.all(self.notchZh[0] == 0):
|
||||
zi = np.zeros([max(len(self.north_a), len(self.north_b)) - 1, self.num_chans])
|
||||
_, self.notchZh[0] = lfilter(self.north_b, self.north_a, data, zi=zi.T)
|
||||
dataFiltered = lfilter(self.north_b, self.north_a, data)
|
||||
else:
|
||||
dataFiltered, self.notchZh[0] = lfilter(self.north_b, self.north_a, data, zi=self.notchZh[0])
|
||||
return np.asmatrix(dataFiltered)
|
||||
except Exception:
|
||||
print(Exception)
|
||||
|
||||
'''
|
||||
getDataQ
|
||||
Inputs:
|
||||
data:脑电数据
|
||||
Rbuffer:待更新的中间系数
|
||||
Output:
|
||||
Qs1 : 脑电特征1
|
||||
Qs2 : 脑电特征2
|
||||
Rbuffer : 单窗口更新后的系数
|
||||
|
||||
'''
|
||||
|
||||
def getDataQ(self, data, Rbuffer):
|
||||
Qs1 = [0] * self.num_fbs
|
||||
Qs2 = [0] * self.num_fbs
|
||||
nulldata = np.zeros([self.num_chans, self.num_chans])
|
||||
Rnum = self.num_chans
|
||||
for fb_num in range(self.num_fbs):
|
||||
fb_data = np.squeeze(data[fb_num, :, :])
|
||||
if np.all(Rbuffer[fb_num] == 0):
|
||||
whiteTemp = fb_data
|
||||
Q, R = qr(whiteTemp, mode='economic')
|
||||
Qs1[fb_num] = nulldata
|
||||
Qs2[fb_num] = Q
|
||||
Rbuffer[fb_num] = R
|
||||
else:
|
||||
whiteTemp = np.concatenate((Rbuffer[fb_num], fb_data), axis=0)
|
||||
Q, R = qr(whiteTemp, mode='economic')
|
||||
Qs1[fb_num] = Q[0:Rnum, :]
|
||||
Qs2[fb_num] = Q[Rnum:, :]
|
||||
Rbuffer[fb_num] = R
|
||||
return Qs1, Qs2, Rbuffer
|
||||
|
||||
'''
|
||||
myCCA:根据脑电特征和参考信号计算相关系数
|
||||
Inputs:
|
||||
dataQ:脑电特征
|
||||
Qc2y:参考信号
|
||||
d : 相关系数取值数
|
||||
Output:
|
||||
rho : 相关系数
|
||||
|
||||
'''
|
||||
|
||||
def myCCA(self, dataQ, Qc2y, d):
|
||||
if len(Qc2y) == 0:
|
||||
Cov = dataQ
|
||||
else:
|
||||
Cov = np.dot(Qc2y, dataQ)
|
||||
# U, S, V = scipy.linalg.svd(Cov, 0)
|
||||
# rho = np.minimum(np.maximum(np.diag(S[0: d]).T, 0), 1)
|
||||
# gufunc = _umath_linalg.svd_n
|
||||
# rho = gufunc(Cov)
|
||||
rho = np.linalg.svd(Cov, compute_uv=False)
|
||||
return rho[0:d]
|
||||
|
||||
'''
|
||||
weightCCA:计算分类标签
|
||||
Inputs:
|
||||
Qs1:脑电特征1
|
||||
Qs2:脑电特征2
|
||||
ref : 正余弦参考信号
|
||||
Cxy : 协方差中间参数
|
||||
Output:
|
||||
result : 分类标签
|
||||
rho : 相关系数
|
||||
Cxy : 更新后的协方差中间参数
|
||||
'''
|
||||
|
||||
def weightCCA(self, Qs1, Qs2, ref, Cxy):
|
||||
rMax = np.zeros([self.num_fbs, self.num_target])
|
||||
for fi in range(self.num_fbs):
|
||||
for si in range(self.num_target):
|
||||
Qc2y = np.squeeze(ref[si, :, :])
|
||||
# 更新协方差矩阵
|
||||
if np.all(Cxy[fi][si] == 0):
|
||||
Cxy[fi, si] = np.dot(Qc2y, Qs2[fi])
|
||||
else:
|
||||
Cxy[fi, si] = np.dot(Cxy[fi, si], Qs1[fi]) + np.dot(Qc2y, Qs2[fi])
|
||||
r = self.myCCA(Cxy[fi, si], [], self.rhoNum)
|
||||
rMax[fi, si] = r[0]
|
||||
rho = np.dot(self.weightValue, np.power(rMax, 2)) # weighted sum of r from all different filter banks' result
|
||||
result = np.argmax(rho)
|
||||
return result, rho, Cxy
|
||||
|
||||
'''
|
||||
costF:损失函数,根据计算的相关系数,生成决策值,用于和阈值进行比较
|
||||
Inputs:
|
||||
rho:相关系数
|
||||
method:相关系数计算参数
|
||||
C : 参数
|
||||
Output:
|
||||
decideValue : 决策阈值
|
||||
'''
|
||||
|
||||
def costF(self, rho, method, C):
|
||||
rho = rho.tolist()
|
||||
rho.sort(reverse=True)
|
||||
if method == 'DW1':
|
||||
decideValue = (rho[0] - rho[1]) / (sum(rho) - self.num_target * np.log(sum(np.exp(rho))))
|
||||
elif method == 'DW11':
|
||||
decideValue = -(rho[0] - rho[1])
|
||||
elif method == 'DW2':
|
||||
decideValue = (rho[0] - C) / (rho[1] - rho[0])
|
||||
return decideValue
|
||||
|
||||
'''
|
||||
onlineInit:将窗口长度,相位值、中间参数初始化
|
||||
'''
|
||||
|
||||
def onlineInit(self):
|
||||
self.dataUseLen = [0] * self.winNum
|
||||
self.Rbuffer = np.zeros([self.winNum, self.num_fbs, self.num_chans, self.num_chans])
|
||||
self.Cxy = np.zeros([self.winNum, self.num_fbs, self.num_target, 2 * self.num_harms, self.num_chans])
|
||||
self.phase = 0
|
||||
|
||||
'''
|
||||
filterInit:重置陷波器和滤波器的滤波参数
|
||||
'''
|
||||
|
||||
def filterInit(self):
|
||||
self.notchZh = [0]
|
||||
self.filterZf = [0] * self.num_fbs
|
||||
|
||||
'''
|
||||
warmFilter:预热滤波器,去刺激前的4S数据对陷波器和滤波器参数进行初始化迭代,去除过渡带的效果
|
||||
Inputs:
|
||||
data:预处理脑电数据
|
||||
'''
|
||||
|
||||
def warmFilter(self, data):
|
||||
# 降采样在采集前完成
|
||||
temp = self.preprocessFilter(data) #预热陷波滤波器
|
||||
# 滤波器组频带分解
|
||||
filterData = self.filterbank(temp) #预热滤波器组
|
||||
|
||||
'''
|
||||
myDownSample:数据降采样
|
||||
Inputs:
|
||||
data:脑电数据
|
||||
n:降采样的倍数
|
||||
Output:
|
||||
eegData2 : 降采样后的数据
|
||||
'''
|
||||
|
||||
def myDownSample(self, data, n):
|
||||
data = data[:8, self.phase:]
|
||||
dataNum = data.shape[1]
|
||||
remainNum = (dataNum - 1) % n
|
||||
self.phase = n - 1 - remainNum
|
||||
dataDowmSample = []
|
||||
for value in data:
|
||||
value = value[0:value.size:n]
|
||||
dataDowmSample.append(value)
|
||||
eegData2 = np.array(dataDowmSample).reshape([8, int(np.round(data.shape[1] / n))])
|
||||
return eegData2
|
||||
|
||||
'''
|
||||
preprocessFilter:预处理,调用函数降采样和陷波处理
|
||||
Inputs:
|
||||
data:脑电数据
|
||||
Output:
|
||||
filterData : 降采样和陷波后的数据
|
||||
'''
|
||||
|
||||
def preprocessFilter(self, data):
|
||||
# data = self.myDownSample(data, 4)
|
||||
# filterData = self.northFilter(data[:8, :])
|
||||
filterData = self.northFilter(data[:, :])
|
||||
return filterData
|
||||
|
||||
'''
|
||||
fbccaDWMW:分类函数,对输入的脑电信号进行识别,输出决策标签
|
||||
Inputs:
|
||||
testdata:脑电数据
|
||||
referenceData:参考信号
|
||||
tValue:出决策阈值
|
||||
Output:
|
||||
res : 决策标签
|
||||
rho_new:相关系数
|
||||
minEps:得到的决策阈值
|
||||
'''
|
||||
|
||||
# 动态窗算法主函数
|
||||
def fbccaDWMW(self, testdata, referenceData, tValue,calculateCount):
|
||||
t1 = time.time()
|
||||
# try:
|
||||
# 初始参数
|
||||
res = -1
|
||||
minEps = float("inf")
|
||||
# 降采样和陷波器处理
|
||||
northData = self.preprocessFilter(testdata)
|
||||
newSampleNum = northData.shape[1]
|
||||
# 数据大于延迟长度,则无法根据后面的规则更新窗口
|
||||
if newSampleNum > self.winDelayNum:
|
||||
error('need add window delay time')
|
||||
|
||||
# 防止秩小于导联数
|
||||
if newSampleNum < self.num_chans:
|
||||
warnings.warn('data shape is [%d %d] need more data' % (newSampleNum, northData.shape[0]))
|
||||
# 滤波器组频带分解
|
||||
filterData = self.filterbank(northData)
|
||||
winMinTime = 0
|
||||
# 计算每个窗口的结果
|
||||
for wi in range(0, self.winNum, self.winStep):
|
||||
# print('dataUseLen:',wi,calculateCount, self.dataUseLen)
|
||||
if wi == 0:
|
||||
self.dataUseLen[wi] = self.dataUseLen[wi] + newSampleNum
|
||||
else:
|
||||
if self.dataUseLen[wi] == 0:
|
||||
# 判断当前窗是否为新的窗口(因为每一次新的窗口进来时,都会使上一个窗口datauseLen>50)
|
||||
if self.dataUseLen[wi - self.winStep] > self.winDelayNum*self.winStep:
|
||||
self.dataUseLen[wi] = newSampleNum
|
||||
else:
|
||||
# print('中断: ',wi,calculateCount)
|
||||
break
|
||||
else:
|
||||
self.dataUseLen[wi] = self.dataUseLen[wi] + newSampleNum
|
||||
|
||||
if self.dataUseLen[wi] > self.winMaxSampleNum:
|
||||
self.dataUseLen[wi] = newSampleNum
|
||||
self.Rbuffer[wi, :, :, :] = 0
|
||||
self.Cxy[wi, :, :, :, :] = 0
|
||||
Qs1, Qs2, self.Rbuffer[wi, :, :, :] = self.getDataQ(filterData, self.Rbuffer[wi, :, :, :])
|
||||
si = self.dataUseLen[wi] - newSampleNum
|
||||
ei = self.dataUseLen[wi]
|
||||
ref = referenceData[:, :, si:ei]
|
||||
# 更新协方差
|
||||
predLabel_new, rho_new, self.Cxy[wi, :, :, :] = self.weightCCA(Qs1, Qs2, ref, self.Cxy[wi, :, :, :])
|
||||
# 增加限制,数据长度不能太短
|
||||
if self.dataUseLen[wi] > winMinTime * self.fs:
|
||||
epsilon = self.costF(rho_new, self.DW_cost_method, C=0)
|
||||
if epsilon < minEps:
|
||||
minEps = epsilon
|
||||
predLabel = predLabel_new
|
||||
xxx = rho_new
|
||||
if minEps < tValue:
|
||||
res = predLabel
|
||||
|
||||
if time.time() - t1 > 0.2 and self.winStep < 16:
|
||||
self.winStep = self.winStep * 2
|
||||
# print(self.winStep, " ", time.time() - t1)
|
||||
# if res != -1:
|
||||
# print('--------------------- ',res,xxx,' --------------------------')
|
||||
return res
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# The number of sub-bands in filter bank analysis
|
||||
fs = 250
|
||||
num_chans = 8
|
||||
num_target = 40
|
||||
num_filterBank = 3
|
||||
num_harm = 5
|
||||
stimTime = 0.2 # 多窗口窗长
|
||||
winNum = 50 # 窗口的个数
|
||||
trials = 1
|
||||
step = 50
|
||||
res = -1
|
||||
list_freqs = [8.0, 8.2, 8.4, 8.6, 8.8, 9.0, 9.2, 9.4, 9.6, 9.8, 10.0, 10.2, 10.4, 10.6, 10.8, 11.0, 11.2, 11.4,
|
||||
11.6, 11.8, 12.0, 12.2, 12.4, 12.6, 12.8, 13., 13.2, 13.4, 13.6, 13.8, 14., 14.2, 14.4, 14.6, 14.8,
|
||||
15.0, 15.2, 15.4, 15.6, 15.8]
|
||||
# 初始化对象
|
||||
dw = FbccaDw(fs, num_target, num_chans, num_filterBank, num_harm, stimTime, [1.0, 0.3], [8, 8], winNum)
|
||||
# frequenceband
|
||||
dw.filterFrequenceBank()
|
||||
referenceData = dw.reference(list_freqs, int(winNum * stimTime * fs), num_harm)
|
||||
dw.setNotchFilterPara()
|
||||
|
||||
prelabels = np.zeros((1, 40))
|
||||
coefficient = np.zeros([1, 1])
|
||||
path = "D:\\工作相关\\项目代码\\SDK封装与测试\\八神BCI脑电信号SDK20230105\\双鹰SDK对比\\offline_data250\\"
|
||||
for index in range(1, trials + 1):
|
||||
D = loadmat(os.path.join(path + str(1) + '-warmData.mat'))
|
||||
warmData = D['warmData']
|
||||
dw.onlineInit()
|
||||
dw.filterInit()
|
||||
dw.warmFilter(warmData.T)
|
||||
|
||||
tagget_i = 0
|
||||
for tagget_i in range(1, step + 1):
|
||||
D = loadmat(os.path.join(path + str(1) + '-' + str(tagget_i) + '.mat'))
|
||||
dataSlice = D['dataTemp']
|
||||
res = dw.fbccaDWMW(dataSlice.T, referenceData, tValue=-0.2)
|
||||
if res != -1:
|
||||
break
|
||||
prelabels[0, index - 1] = res + 1
|
||||
print(index, '--', res + 1," 计算轮数", tagget_i)
|
||||
Reference in New Issue
Block a user