530 lines
20 KiB
Python
530 lines
20 KiB
Python
|
|
|
|||
|
|
|
|||
|
|
# -*- coding: utf-8 -*-
|
|||
|
|
import os
|
|||
|
|
import time
|
|||
|
|
import warnings
|
|||
|
|
from os import error
|
|||
|
|
import numpy as np
|
|||
|
|
import scipy
|
|||
|
|
from numpy.linalg import linalg
|
|||
|
|
from scipy.io import loadmat
|
|||
|
|
from scipy.linalg import qr
|
|||
|
|
from scipy.signal import filtfilt, lfilter
|
|||
|
|
# from numpy.linalg import _umath_linalg
|
|||
|
|
|
|||
|
|
|
|||
|
|
class FbccaDw:
    """Online filter-bank CCA (FBCCA) classifier with dynamic, multi-window stopping.

    Incoming EEG chunks are notch filtered, split into filter-bank sub-bands and
    accumulated into ``winNum`` staggered windows. Each window keeps an
    incremental QR factor (``Rbuffer``) of its data and a running projection of
    the sin/cos reference onto that data (``Cxy``). This lets canonical
    correlations be updated chunk by chunk without re-factorising the whole
    window.

    Args:
        fs: sampling rate (Hz) of the data handed to ``fbccaDWMW``.
        num_target: number of stimulation targets.
        num_chans: number of EEG channels.
        num_filter: number of filter-bank sub-bands.
        num_harms: number of harmonics in the sin/cos reference.
        stimTime: window spacing in seconds; a new window starts roughly every
            ``stimTime`` seconds of data.
        parameter: ``[a, b]``; sub-band ``k`` (1-based) gets weight ``k**-a + b``
            before normalisation.
        width: ``[step, start]`` in Hz used to place the sub-band lower edges.
        winNum: number of staggered windows.
        method: ``1`` selects the ``'DW11'`` cost, anything else ``'DW1'``.
    """

    def __init__(self, fs, num_target, num_chans, num_filter, num_harms, stimTime, parameter, width, winNum, method):
        print('******************************************')
        print('parameter list')
        print('target:', num_target)
        print('number of filter bank:', num_filter)
        print('parameter:', parameter)
        print('width:', width)
        self.bandWidth = width
        self.winNum = winNum
        self.num_harms = num_harms
        self.num_target = num_target
        self.num_chans = num_chans
        self.winTimeDelay = stimTime
        self.fs = fs
        # A window that has accumulated more samples than this is restarted.
        self.winMaxSampleNum = self.winTimeDelay * self.winNum * self.fs
        # Spacing (in samples) between the starts of consecutive windows.
        self.winDelayNum = round(self.winTimeDelay * self.fs)
        self.num_fbs = num_filter
        # Sub-band weights k**-a + b, normalised to sum to one. A float base
        # keeps np.power valid when the caller passes an integer exponent.
        parameterValue = np.power(np.arange(1, self.num_fbs + 1, dtype=float), -(parameter[0])) + parameter[1]
        self.weightValue = parameterValue / (sum(parameterValue))
        # Number of canonical correlations returned by myCCA.
        self.rhoNum = 2
        self.north_b = []
        self.north_a = []
        self.filterBank_A = []
        self.filterBank_B = []
        # Stride over the windows; doubled (up to 16) when a call runs slowly.
        self.winStep = 1
        self.DW_cost_method = 'DW11' if method == 1 else 'DW1'
        # Per-window accumulators (dataUseLen, Rbuffer, Cxy, phase) and the
        # streaming filter states (notchZh, filterZf).
        self.onlineInit()
        self.filterInit()

    def filterFrequenceBank(self):
        """Design the Chebyshev type-I band-pass filters of the filter bank.

        Sub-band ``k`` (0-based) has its passband at
        ``[max(start + k * step - 2, 1), 92]`` Hz and its stopband edges at
        ``[max(pass_low - 4, 1), 102]`` Hz, where ``step, start = self.bandWidth``.
        The order comes from ``cheb1ord`` (3 dB passband loss, 40 dB stopband
        attenuation). The filter is then designed with 0.5 dB ripple.

        Raises:
            ValueError: if ``self.bandWidth`` yields fewer lower edges than
                ``self.num_fbs`` sub-bands.
        """
        # Upper end of the range scanned for sub-band lower edges (Hz); the
        # passband's upper edge sits 2 Hz above it.
        lastFrequence = 90
        fStep = self.bandWidth[0]
        freqBandWidth = self.bandWidth[1]

        passLows = [max(x - 2, 1) for x in range(freqBandWidth, lastFrequence, fStep)]
        if len(passLows) < self.num_fbs:
            raise ValueError('bandWidth %r yields only %d sub-bands, %d requested'
                             % (self.bandWidth, len(passLows), self.num_fbs))
        passLows = np.asarray(passLows[:self.num_fbs], dtype=float)
        stopLows = np.maximum(passLows - 4, 1)
        passHigh = lastFrequence + 2
        stopHigh = passHigh + 10

        nyquist = self.fs / 2
        # Rebuild from scratch so repeated calls do not keep appending filters.
        self.filterBank_A = []
        self.filterBank_B = []
        for idx_fb in range(self.num_fbs):
            Wp = [passLows[idx_fb] / nyquist, passHigh / nyquist]
            Ws = [stopLows[idx_fb] / nyquist, stopHigh / nyquist]
            N, Wn = scipy.signal.cheb1ord(Wp, Ws, 3, 40)
            B, A = scipy.signal.cheby1(N, 0.5, Wn, 'bandpass')
            self.filterBank_A.append(A)
            self.filterBank_B.append(B)

    def filterbank(self, eeg):
        """Split a chunk into filter-bank sub-bands, carrying filter state across calls.

        Args:
            eeg: array-like of shape (channels, samples); the channel count must
                equal ``self.num_chans``.

        Returns:
            ndarray of shape (num_fbs, samples, channels).
        """
        eeg = np.asarray(eeg)
        filterData = np.zeros((self.num_fbs, eeg.shape[1], eeg.shape[0]))
        for filterIndex in range(self.num_fbs):
            b = self.filterBank_B[filterIndex]
            a = self.filterBank_A[filterIndex]
            zi = self.filterZf[filterIndex]
            if np.all(zi == 0):
                # First chunk (or reset): start from rest. Filtering with a zero
                # initial state is identical to filtering without one.
                zi = np.zeros((self.num_chans, max(len(a), len(b)) - 1))
            Data, self.filterZf[filterIndex] = lfilter(b, a, eeg, zi=zi)
            filterData[filterIndex] = Data.T
        return filterData

    def process(self, data):
        """Centre each row of ``data`` and return an orthonormal basis of the rows.

        Args:
            data: 2-D array-like of shape (signals, samples).

        Returns:
            (Q, rankQ): ``Q`` has shape (samples, rankQ) and spans the centred
            rows; ``rankQ`` is the numerical rank of the centred data.

        Raises:
            ValueError: if the centred data has rank zero.
        """
        data = np.asarray(data, dtype=float)
        centred = data - data.mean(axis=1, keepdims=True)
        Q, R = qr(centred.T, mode='economic')
        rankQ = np.linalg.matrix_rank(R)
        if rankQ == 0:
            raise ValueError('stats:canoncorr:badData')
        # Keep only the columns that span the data (as MATLAB canoncorr does).
        return Q[:, :rankQ], rankQ

    def reference(self, listFreqs, numberSmples, num_harms):
        """Build orthonormalised sin/cos reference signals for every target.

        Args:
            listFreqs: stimulation frequencies in Hz, one per target.
            numberSmples: number of samples per reference.
            num_harms: number of harmonics per frequency.

        Returns:
            ndarray of shape (len(listFreqs), 2 * num_harms, numberSmples). Each
            target's rows are the transposed ``process`` basis of its
            sin/cos set. This assumes that set has full rank.
        """
        timeIndex = np.arange(1, numberSmples + 1) / self.fs
        referenceData = np.zeros((len(listFreqs), 2 * num_harms, numberSmples))
        for frequenceIndex, stimFrequence in enumerate(listFreqs):
            rows = []
            for harmIndex in range(1, num_harms + 1):
                rows.extend([np.sin(2 * np.pi * timeIndex * harmIndex * stimFrequence),
                             np.cos(2 * np.pi * timeIndex * harmIndex * stimFrequence)])
            Q, _ = self.process(np.asarray(rows))
            referenceData[frequenceIndex] = Q.T
        return referenceData

    def setNotchFilterPara(self):
        """Load the fixed 50 Hz notch filter coefficients.

        The numerator equals ``g**3 * (1 - 2*cos(2*pi*50/250) z^-1 + z^-2)**3``:
        three cascaded 50 Hz notches at fs = 250 Hz. For any other sampling rate
        the notch lands at the wrong frequency, so a warning is issued.
        """
        if self.fs != 250:
            warnings.warn('notch coefficients are designed for fs=250 Hz, got fs=%r' % (self.fs,))
        self.north_b = [0.94801603944125245604368501517456, -1.7577184027642647201616910024313,
                        3.9303778338832491279219993884908, -3.7392330345967859095424046245171,
                        3.9303778338832482397435796883656, -1.7577184027642638319832713023061,
                        0.94801603944125156786526531504933]
        self.north_a = [1, -1.8214007435820627200939725298667, 4.0000101767406484043476666556671,
                        -3.7380998614928691026193519064691, 3.8589119784285759173769747576443,
                        -1.6951692350503837491970671180752, 0.89786559147978006745205448169145]

    def northFilter(self, data):
        """Apply the 50 Hz notch filter, carrying filter state across calls.

        Args:
            data: array-like of shape (channels, samples).

        Returns:
            np.matrix of the filtered data, same shape as ``data``.
        """
        b, a = self.north_b, self.north_a
        zi = self.notchZh[0]
        if np.all(zi == 0):
            # Start from rest; equivalent to filtering without initial state.
            zi = np.zeros((self.num_chans, max(len(a), len(b)) - 1))
        dataFiltered, self.notchZh[0] = lfilter(b, a, data, zi=zi)
        return np.asmatrix(dataFiltered)

    def getDataQ(self, data, Rbuffer):
        """Update one window's incremental QR factors with a new chunk.

        For each sub-band the previous R factor is stacked on the new samples and
        re-factorised. The resulting Q is split into the part acting on the old R
        (``Qs1``) and the part acting on the new samples (``Qs2``).

        Args:
            data: ndarray of shape (num_fbs, samples, channels).
            Rbuffer: ndarray of shape (num_fbs, channels, channels); updated in
                place, all zeros means "empty window".

        Returns:
            (Qs1, Qs2, Rbuffer): per-sub-band lists of Q blocks and the updated
            buffer. ``Qs1`` entries are zeros for an empty window.
        """
        Qs1 = [0] * self.num_fbs
        Qs2 = [0] * self.num_fbs
        nulldata = np.zeros([self.num_chans, self.num_chans])
        Rnum = self.num_chans
        for fb_num in range(self.num_fbs):
            # Index without squeeze so a one-sample chunk stays 2-D.
            fb_data = data[fb_num]
            if np.all(Rbuffer[fb_num] == 0):
                Q, R = qr(fb_data, mode='economic')
                Qs1[fb_num] = nulldata
                Qs2[fb_num] = Q
            else:
                Q, R = qr(np.concatenate((Rbuffer[fb_num], fb_data), axis=0), mode='economic')
                Qs1[fb_num] = Q[0:Rnum, :]
                Qs2[fb_num] = Q[Rnum:, :]
            Rbuffer[fb_num] = R
        return Qs1, Qs2, Rbuffer

    def myCCA(self, dataQ, Qc2y, d):
        """Return the ``d`` largest singular values of ``Qc2y @ dataQ``.

        Args:
            dataQ: data-side matrix (or the precomputed product when ``Qc2y``
                is empty).
            Qc2y: reference-side matrix, or an empty sequence.
            d: number of correlations to return.

        Returns:
            1-D ndarray of at most ``d`` singular values, largest first.
        """
        Cov = dataQ if len(Qc2y) == 0 else np.dot(Qc2y, dataQ)
        return np.linalg.svd(Cov, compute_uv=False)[0:d]

    def weightCCA(self, Qs1, Qs2, ref, Cxy):
        """Update the reference/data projections and score every target.

        Args:
            Qs1, Qs2: per-sub-band Q blocks from ``getDataQ``.
            ref: reference slice of shape (targets, 2 * harms, samples) aligned
                with the new chunk.
            Cxy: ndarray of shape (num_fbs, targets, 2 * harms, channels);
                updated in place.

        Returns:
            (result, rho, Cxy): index of the best target, weighted score per
            target (sum over sub-bands of weight * r**2), and the updated Cxy.
        """
        rMax = np.zeros([self.num_fbs, self.num_target])
        for fi in range(self.num_fbs):
            for si in range(self.num_target):
                # Index without squeeze so a one-sample slice stays 2-D.
                Qc2y = ref[si]
                if np.all(Cxy[fi, si] == 0):
                    Cxy[fi, si] = np.dot(Qc2y, Qs2[fi])
                else:
                    # Rotate the old projection into the new basis, add the new part.
                    Cxy[fi, si] = np.dot(Cxy[fi, si], Qs1[fi]) + np.dot(Qc2y, Qs2[fi])
                rMax[fi, si] = self.myCCA(Cxy[fi, si], [], self.rhoNum)[0]
        rho = np.dot(self.weightValue, np.power(rMax, 2))
        result = np.argmax(rho)
        return result, rho, Cxy

    def costF(self, rho, method, C):
        """Turn per-target scores into a decision value (smaller = more confident).

        Args:
            rho: 1-D array of per-target scores (at least two entries).
            method: ``'DW1'``, ``'DW11'`` or ``'DW2'``.
            C: offset used only by ``'DW2'``.

        Returns:
            The decision value, compared against the threshold by the caller.

        Raises:
            ValueError: for an unknown ``method``.
        """
        rho = sorted(np.asarray(rho).ravel().tolist(), reverse=True)
        if method == 'DW1':
            return (rho[0] - rho[1]) / (sum(rho) - self.num_target * np.log(sum(np.exp(rho))))
        if method == 'DW11':
            return -(rho[0] - rho[1])
        if method == 'DW2':
            return (rho[0] - C) / (rho[1] - rho[0])
        raise ValueError("unknown cost method %r (expected 'DW1', 'DW11' or 'DW2')" % (method,))

    def onlineInit(self):
        """Reset window lengths, the decimation phase and the per-window accumulators."""
        self.dataUseLen = [0] * self.winNum
        self.Rbuffer = np.zeros([self.winNum, self.num_fbs, self.num_chans, self.num_chans])
        self.Cxy = np.zeros([self.winNum, self.num_fbs, self.num_target, 2 * self.num_harms, self.num_chans])
        self.phase = 0

    def filterInit(self):
        """Reset the notch and filter-bank states (0 means "start from rest")."""
        self.notchZh = [0]
        self.filterZf = [0] * self.num_fbs

    def warmFilter(self, data):
        """Run pre-stimulus data through the notch and filter bank to settle their states.

        Args:
            data: array-like of shape (channels, samples). Any decimation must
                already have been applied.
        """
        self.filterbank(self.preprocessFilter(data))

    def myDownSample(self, data, n):
        """Keep every ``n``-th sample, continuing the decimation grid across calls.

        Args:
            data: array-like of shape (channels, samples); only the first
                ``self.num_chans`` rows are kept.
            n: decimation factor.

        Returns:
            ndarray of shape (rows, ceil(samples_after_phase / n)).
        """
        data = np.asarray(data)[:self.num_chans, self.phase:]
        remainNum = (data.shape[1] - 1) % n
        # Offset into the next chunk so the next kept sample is n after the last.
        self.phase = n - 1 - remainNum
        return data[:, ::n]

    def preprocessFilter(self, data):
        """Apply the 50 Hz notch filter (decimation is expected to happen upstream).

        Args:
            data: array-like of shape (channels, samples).

        Returns:
            np.matrix of the notch-filtered data.
        """
        return self.northFilter(data[:, :])

    def fbccaDWMW(self, testdata, referenceData, tValue, calculateCount=None):
        """Classify one new chunk using all active dynamic windows.

        Args:
            testdata: array-like of shape (channels, samples) holding the new chunk.
            referenceData: output of ``reference`` covering at least
                ``winMaxSampleNum`` samples.
            tValue: decision threshold; a label is emitted only when the best
                window's cost is below it.
            calculateCount: unused; kept for backward compatibility.

        Returns:
            Index of the detected target, or -1 when no window is confident enough.
        """
        t1 = time.time()
        res = -1
        minEps = float("inf")
        predLabel = -1

        northData = self.preprocessFilter(testdata)
        newSampleNum = northData.shape[1]
        if newSampleNum > self.winDelayNum:
            # Chunks longer than the window spacing break the window scheduling below.
            warnings.warn('need add window delay time')
        if newSampleNum < self.num_chans:
            # Fewer samples than channels makes the QR factors rank-deficient.
            warnings.warn('data shape is [%d %d] need more data' % (newSampleNum, northData.shape[0]))

        filterData = self.filterbank(northData)
        winMinTime = 0

        for wi in range(0, self.winNum, self.winStep):
            if wi == 0 or self.dataUseLen[wi] != 0:
                self.dataUseLen[wi] += newSampleNum
            elif self.dataUseLen[wi - self.winStep] > self.winDelayNum * self.winStep:
                # Open this window once the previous one is far enough ahead.
                self.dataUseLen[wi] = newSampleNum
            else:
                # Later windows cannot be open either.
                break

            if self.dataUseLen[wi] > self.winMaxSampleNum:
                # Window exceeded its maximum length: restart it with this chunk.
                self.dataUseLen[wi] = newSampleNum
                self.Rbuffer[wi] = 0
                self.Cxy[wi] = 0

            Qs1, Qs2, self.Rbuffer[wi] = self.getDataQ(filterData, self.Rbuffer[wi])
            ei = self.dataUseLen[wi]
            si = ei - newSampleNum
            ref = referenceData[:, :, si:ei]
            predLabel_new, rho_new, self.Cxy[wi] = self.weightCCA(Qs1, Qs2, ref, self.Cxy[wi])

            if self.dataUseLen[wi] > winMinTime * self.fs:
                epsilon = self.costF(rho_new, self.DW_cost_method, C=0)
                if epsilon < minEps:
                    minEps = epsilon
                    predLabel = predLabel_new

        if minEps < tValue:
            res = predLabel

        # Visit fewer windows on later calls when this one was slow.
        if time.time() - t1 > 0.2 and self.winStep < 16:
            self.winStep = self.winStep * 2
        return res
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == '__main__':
    # Offline replay: warm the filters on pre-stimulus data, then feed the saved
    # data slices one at a time until the classifier emits a label.
    fs = 250  # sampling rate (Hz) of the saved data
    num_chans = 8
    num_target = 40
    num_filterBank = 3  # number of sub-bands in the filter-bank analysis
    num_harm = 5
    stimTime = 0.2  # window spacing (s)
    winNum = 50  # number of windows
    trials = 1
    step = 50  # maximum number of data slices replayed per trial
    tValue = -0.2  # decision threshold, on the scale of the 'DW11' cost
    method = 1  # 1 selects the 'DW11' cost, which tValue above is tuned for
    list_freqs = [8.0, 8.2, 8.4, 8.6, 8.8, 9.0, 9.2, 9.4, 9.6, 9.8, 10.0, 10.2, 10.4, 10.6, 10.8, 11.0, 11.2, 11.4,
                  11.6, 11.8, 12.0, 12.2, 12.4, 12.6, 12.8, 13., 13.2, 13.4, 13.6, 13.8, 14., 14.2, 14.4, 14.6, 14.8,
                  15.0, 15.2, 15.4, 15.6, 15.8]

    # Build the classifier, its filter bank, reference signals and notch filter.
    dw = FbccaDw(fs, num_target, num_chans, num_filterBank, num_harm, stimTime, [1.0, 0.3], [8, 8], winNum,
                 method)
    dw.filterFrequenceBank()
    referenceData = dw.reference(list_freqs, int(winNum * stimTime * fs), num_harm)
    dw.setNotchFilterPara()

    prelabels = np.zeros((1, 40))
    path = "D:\\工作相关\\项目代码\\SDK封装与测试\\八神BCI脑电信号SDK20230105\\双鹰SDK对比\\offline_data250\\"
    for index in range(1, trials + 1):
        D = loadmat(os.path.join(path, str(index) + '-warmData.mat'))
        warmData = D['warmData']
        dw.onlineInit()
        dw.filterInit()
        dw.warmFilter(warmData.T)

        res = -1
        tagget_i = 0
        for tagget_i in range(1, step + 1):
            D = loadmat(os.path.join(path, str(index) + '-' + str(tagget_i) + '.mat'))
            dataSlice = D['dataTemp']
            res = dw.fbccaDWMW(dataSlice.T, referenceData, tValue, tagget_i)
            if res != -1:
                break
        # Labels are stored 1-based; 0 means "no decision".
        prelabels[0, index - 1] = res + 1
        print(index, '--', res + 1, " 计算轮数", tagget_i)
|