530 lines
20 KiB
Python
530 lines
20 KiB
Python
|
||
|
||
# -*- coding: utf-8 -*-
|
||
import os
|
||
import time
|
||
import warnings
|
||
from os import error
|
||
import numpy as np
|
||
import scipy
|
||
from numpy.linalg import linalg
|
||
from scipy.io import loadmat
|
||
from scipy.linalg import qr
|
||
from scipy.signal import filtfilt, lfilter
|
||
# from numpy.linalg import _umath_linalg
|
||
|
||
|
||
class FbccaDw:
    """Online filter-bank CCA classifier with dynamic, staggered analysis windows.

    Workflow: ``filterFrequenceBank()`` and ``setNotchFilterPara()`` design the
    filters, ``reference()`` builds the orthonormalised sin/cos references,
    ``onlineInit()`` / ``filterInit()`` reset the per-trial state,
    ``warmFilter()`` primes the filters, and every new data chunk is passed to
    ``fbccaDWMW()``, which returns a target index once the best window's cost
    value drops below the threshold (otherwise -1).

    Up to ``winNum`` windows are tracked. Window ``wi`` starts once window
    ``wi - winStep`` holds more than ``winDelayNum * winStep`` samples, and a
    window restarts when it would exceed ``winMaxSampleNum`` samples. Each
    window keeps an incremental QR factor (``Rbuffer``) and the projected
    reference/data product (``Cxy``), so new chunks are folded in without
    re-processing older samples.
    """

    def __init__(self, fs, num_target, num_chans, num_filter, num_harms, stimTime, parameter, width, winNum, method):
        """Store the configuration and allocate the online state.

        Args:
            fs: Sampling rate in Hz.
            num_target: Number of targets (reference frequencies).
            num_chans: Number of EEG channels.
            num_filter: Number of filter-bank sub-bands.
            num_harms: Number of harmonics in the sin/cos references.
            stimTime: Delay in seconds between the starts of consecutive
                windows; the longest window is ``stimTime * winNum`` seconds.
            parameter: ``(a, b)``; sub-band ``k`` (1-based) gets weight
                ``k**-a + b`` before normalisation to a sum of 1.
            width: ``(step, start)`` in Hz used by :meth:`filterFrequenceBank`.
            winNum: Number of analysis windows.
            method: 1 selects the 'DW11' cost, any other value selects 'DW1'.
        """
        print('******************************************')
        print('parameter list')
        print('target:', num_target)
        print('number of filter bank:', num_filter)
        print('parameter:', parameter)
        print('width:', width)

        self.bandWidth = width
        self.winNum = winNum
        self.num_harms = num_harms
        self.num_target = num_target
        self.num_chans = num_chans
        self.winTimeDelay = stimTime
        self.fs = fs
        # Longest window (in samples) before it restarts, and the start-to-start delay.
        self.winMaxSampleNum = self.winTimeDelay * self.winNum * self.fs
        self.winDelayNum = round(self.winTimeDelay * self.fs)
        self.num_fbs = num_filter

        # Sub-band weights k**-a + b, normalised to sum to 1. A float base lets
        # integer exponents (e.g. parameter=[1, 0]) work; numpy rejects
        # integer ** negative-integer.
        parameterValue = np.power(np.arange(1, self.num_fbs + 1, dtype=float), -(parameter[0])) + parameter[1]
        self.weightValue = parameterValue / (sum(parameterValue))

        self.rhoNum = 2  # number of canonical correlations returned by myCCA
        self.north_b = []
        self.north_a = []
        self.filterBank_A = []
        self.filterBank_B = []
        self.winStep = 1  # stride over windows; fbccaDWMW doubles it when a call is slow
        self.DW_cost_method = 'DW11' if method == 1 else 'DW1'

        # Per-trial window state (dataUseLen, Rbuffer, Cxy, phase) and filter state.
        self.onlineInit()
        self.filterInit()

    def filterFrequenceBank(self):
        """Design one Chebyshev type-I band-pass filter per sub-band.

        Sub-band ``k`` has a pass band from ``max(start - 2 + k*step, 1)`` Hz
        to 92 Hz, with stop-band edges 4 Hz below (at least 1 Hz) and at
        102 Hz, where ``(step, start) = self.bandWidth``. The coefficients
        replace ``filterBank_B`` / ``filterBank_A``.

        Raises:
            ValueError: if ``bandWidth`` yields fewer than ``num_fbs`` bands.
        """
        lastFrequence = 90  # upper limit used to enumerate the sub-band start frequencies
        fStep = self.bandWidth[0]
        freqBandWidth = self.bandWidth[1]

        # Low pass-band edge of each sub-band, clamped to at least 1 Hz.
        passLow = np.maximum(np.arange(freqBandWidth, lastFrequence, fStep) - 2, 1)
        if len(passLow) < self.num_fbs:
            raise ValueError('bandWidth %r gives only %d sub-bands, %d requested'
                             % (self.bandWidth, len(passLow), self.num_fbs))

        passHigh = lastFrequence + 2
        stopHigh = passHigh + 10
        Nq = self.fs / 2

        self.filterBank_A = []
        self.filterBank_B = []
        for idx_fb in range(self.num_fbs):
            stopLow = max(passLow[idx_fb] - 4, 1)
            Wp = [passLow[idx_fb] / Nq, passHigh / Nq]
            Ws = [stopLow / Nq, stopHigh / Nq]
            # At most 3 dB pass-band ripple and at least 40 dB stop-band attenuation.
            N, Wn = scipy.signal.cheb1ord(Wp, Ws, 3, 40)
            B, A = scipy.signal.cheby1(N, 0.5, Wn, 'bandpass')
            self.filterBank_A.append(A)
            self.filterBank_B.append(B)

    def filterbank(self, eeg):
        """Split EEG into the filter-bank sub-bands, carrying filter state across calls.

        Args:
            eeg: 2-D array (channels, samples). The channel count must equal
                ``num_chans``, the size of the stored filter state.

        Returns:
            Array of shape (num_fbs, samples, channels).
        """
        filterData = np.zeros((self.num_fbs, eeg.shape[1], eeg.shape[0]))
        for filterIndex in range(self.num_fbs):
            b = self.filterBank_B[filterIndex]
            a = self.filterBank_A[filterIndex]
            zi = self.filterZf[filterIndex]
            if np.all(zi == 0):
                # No stored state yet: start from rest. This gives the same
                # output as filtering without initial conditions.
                zi = np.zeros((self.num_chans, max(len(a), len(b)) - 1))
            Data, self.filterZf[filterIndex] = lfilter(b, a, eeg, zi=zi)
            filterData[filterIndex] = np.asarray(Data).T
        return filterData

    def process(self, data):
        """Centre each row of ``data`` and return an orthonormal basis of the rows.

        Args:
            data: 2-D array or matrix (signals, samples).

        Returns:
            (Q, rankQ): ``Q`` has shape (samples, rankQ) with orthonormal
            columns spanning the centred signals; ``rankQ`` is their numerical
            rank.

        Raises:
            ValueError: 'stats:canoncorr:badData' if the centred data has rank 0.
        """
        # np.asarray also turns np.matrix into a plain ndarray, so keepdims
        # broadcasting works for both input types.
        centred = np.asarray(data, dtype=float)
        centred = centred - centred.mean(axis=1, keepdims=True)

        Q, R = qr(centred.T, mode='economic')
        rankQ = np.linalg.matrix_rank(R)
        if rankQ == 0:
            raise ValueError('stats:canoncorr:badData')
        # rank(R) never exceeds the number of signals, so keep the leading rankQ columns.
        return Q[:, :rankQ], rankQ

    def reference(self, listFreqs, numberSmples, num_harms):
        """Build orthonormalised sin/cos references for every target frequency.

        Args:
            listFreqs: Target frequencies in Hz.
            numberSmples: Number of samples per reference (time starts at 1/fs).
            num_harms: Number of harmonics.

        Returns:
            Array (len(listFreqs), 2 * num_harms, numberSmples). Row pairs start
            as sin/cos of each harmonic and are then replaced by the
            orthonormal basis from :meth:`process`.

        Raises:
            ValueError: if a target's reference set is rank deficient.
        """
        timeIndex = np.arange(1, numberSmples + 1) / self.fs
        referenceData = np.zeros((len(listFreqs), 2 * num_harms, numberSmples))
        for frequenceIndex, stimFrequence in enumerate(listFreqs):
            rows = []
            for harmIndex in range(1, num_harms + 1):
                arg = 2 * np.pi * timeIndex * harmIndex * stimFrequence
                rows.extend([np.sin(arg), np.cos(arg)])
            # Centre and QR-orthonormalise the sin/cos set.
            Q, rankQ = self.process(np.array(rows))
            if rankQ < 2 * num_harms:
                raise ValueError('reference for %s Hz is rank deficient (%d < %d)'
                                 % (stimFrequence, rankQ, 2 * num_harms))
            referenceData[frequenceIndex] = Q.T
        return referenceData

    def setNotchFilterPara(self):
        """Load the fixed 50 Hz notch coefficients into ``north_b`` / ``north_a``.

        The coefficients are a scipy ``iirnotch`` (Q=35) cascaded three times,
        a 6th-order filter whose zeros lie at +/-0.4*pi rad/sample. That is
        50 Hz only when ``fs == 250``, so any other sampling rate triggers a
        warning.
        """
        if self.fs != 250:
            warnings.warn('notch coefficients were designed for fs=250 Hz; got fs=%s' % (self.fs,))

        self.north_b = [0.94801603944125245604368501517456, -1.7577184027642647201616910024313,
                        3.9303778338832491279219993884908, -3.7392330345967859095424046245171,
                        3.9303778338832482397435796883656, -1.7577184027642638319832713023061,
                        0.94801603944125156786526531504933]

        self.north_a = [1, -1.8214007435820627200939725298667, 4.0000101767406484043476666556671,
                        -3.7380998614928691026193519064691, 3.8589119784285759173769747576443,
                        -1.6951692350503837491970671180752, 0.89786559147978006745205448169145]

    def northFilter(self, data):
        """Apply the notch filter to ``data``, carrying filter state across calls.

        Args:
            data: 2-D array (channels, samples); the channel count must equal
                ``num_chans``.

        Returns:
            The filtered data as ``np.matrix`` with the same shape as ``data``.

        Raises:
            ValueError: propagated from ``lfilter``, e.g. when ``data`` has a
                channel count different from the stored filter state.
        """
        zi = self.notchZh[0]
        if np.all(zi == 0):
            # No stored state yet: start from rest.
            zi = np.zeros((self.num_chans, max(len(self.north_a), len(self.north_b)) - 1))
        dataFiltered, self.notchZh[0] = lfilter(self.north_b, self.north_a, data, zi=zi)
        return np.asmatrix(dataFiltered)

    def getDataQ(self, data, Rbuffer):
        """Fold a new chunk into each sub-band's incremental QR factorisation.

        With ``[R_old; X_new] = Q @ R_new``, the top ``num_chans`` rows of ``Q``
        (``Qs1``) map the previous basis onto the new one, and the remaining
        rows (``Qs2``) hold the new samples' contribution.

        Args:
            data: Filter-bank output (num_fbs, samples, channels).
            Rbuffer: Per-sub-band R factors (num_fbs, num_chans, num_chans);
                all zeros means "fresh window". Updated in place.

        Returns:
            (Qs1, Qs2, Rbuffer). For a fresh window ``Qs1`` is a zero matrix.
        """
        Qs1 = [0] * self.num_fbs
        Qs2 = [0] * self.num_fbs
        nulldata = np.zeros([self.num_chans, self.num_chans])
        Rnum = self.num_chans
        for fb_num in range(self.num_fbs):
            # Index instead of np.squeeze so a 1-sample chunk stays 2-D.
            fb_data = np.asarray(data[fb_num])
            if np.all(Rbuffer[fb_num] == 0):
                Q, R = qr(fb_data, mode='economic')
                Qs1[fb_num] = nulldata
                Qs2[fb_num] = Q
            else:
                Q, R = qr(np.concatenate((Rbuffer[fb_num], fb_data), axis=0), mode='economic')
                Qs1[fb_num] = Q[:Rnum, :]
                Qs2[fb_num] = Q[Rnum:, :]
            Rbuffer[fb_num] = R
        return Qs1, Qs2, Rbuffer

    def myCCA(self, dataQ, Qc2y, d):
        """Return the ``d`` largest singular values of ``Qc2y @ dataQ``.

        An empty ``Qc2y`` means ``dataQ`` is already the projected product.
        With orthonormal bases on both sides these singular values are the
        canonical correlations.
        """
        Cov = dataQ if len(Qc2y) == 0 else np.dot(Qc2y, dataQ)
        rho = np.linalg.svd(Cov, compute_uv=False)
        return rho[0:d]

    def weightCCA(self, Qs1, Qs2, ref, Cxy):
        """Update one window's projections and score every target.

        For each sub-band ``fi`` (all targets at once):
        ``Cxy[fi] <- Cxy[fi] @ Qs1[fi] + ref @ Qs2[fi]``. The largest singular
        value of each target's projection is its first canonical correlation.
        The squared correlations are combined across sub-bands with
        ``weightValue``.

        Args:
            Qs1, Qs2: Lists returned by :meth:`getDataQ`.
            ref: References for the new samples (targets, 2*num_harms, samples);
                only the first ``num_target`` targets are used.
            Cxy: Projection state (num_fbs, num_target, 2*num_harms, num_chans);
                updated in place.

        Returns:
            (result, rho, Cxy): the best target index, the weighted score per
            target, and the updated ``Cxy``.
        """
        ref = np.asarray(ref)[:self.num_target]
        rMax = np.zeros([self.num_fbs, self.num_target])
        for fi in range(self.num_fbs):
            # Batched matmul: (targets, 2H, samples) @ (samples, channels).
            newPart = np.matmul(ref, Qs2[fi])
            if np.all(Cxy[fi] == 0):
                Cxy[fi] = newPart
            else:
                Cxy[fi] = np.matmul(Cxy[fi], Qs1[fi]) + newPart
            # Batched SVD; singular values come in descending order.
            rMax[fi] = np.linalg.svd(Cxy[fi], compute_uv=False)[:, 0]
        rho = np.dot(self.weightValue, np.power(rMax, 2))  # weighted sum over sub-bands
        result = np.argmax(rho)
        return result, rho, Cxy

    def costF(self, rho, method, C):
        """Turn target scores into a decision value; lower means more confident.

        Args:
            rho: Per-target scores (at least two entries).
            method: 'DW1', 'DW11' or 'DW2'.
            C: Offset used only by 'DW2'.

        Returns:
            The decision value, compared against a threshold by the caller.

        Raises:
            ValueError: for an unknown ``method``.
        """
        rho = sorted(np.asarray(rho).tolist(), reverse=True)
        if method == 'DW1':
            return (rho[0] - rho[1]) / (sum(rho) - self.num_target * np.log(sum(np.exp(rho))))
        if method == 'DW11':
            return -(rho[0] - rho[1])
        if method == 'DW2':
            return (rho[0] - C) / (rho[1] - rho[0])
        raise ValueError('unknown cost method: %r' % (method,))

    def onlineInit(self):
        """Reset the window lengths, decimation phase and per-window QR/projection state."""
        self.dataUseLen = [0] * self.winNum
        self.Rbuffer = np.zeros([self.winNum, self.num_fbs, self.num_chans, self.num_chans])
        self.Cxy = np.zeros([self.winNum, self.num_fbs, self.num_target, 2 * self.num_harms, self.num_chans])
        self.phase = 0

    def filterInit(self):
        """Reset the notch and filter-bank states (0 means "start from rest")."""
        self.notchZh = [0]
        self.filterZf = [0] * self.num_fbs

    def warmFilter(self, data):
        """Prime the notch and filter-bank states with pre-stimulus data.

        The filtered output is discarded; only ``notchZh`` / ``filterZf``
        change, so start-up transients do not reach the classification data.
        The original notes mention using the 4 s before stimulus onset.
        Downsampling is expected to happen before acquisition.
        """
        self.filterbank(self.preprocessFilter(data))

    def myDownSample(self, data, n):
        """Keep every ``n``-th sample of the first 8 rows, with phase carried across calls.

        ``self.phase`` holds the offset, in the next chunk, of the next sample
        to keep, so consecutive chunks form one evenly decimated stream.

        Args:
            data: 2-D array (channels, samples).
            n: Decimation factor.

        Returns:
            A new ndarray (min(8, channels), ceil(kept_samples / n)).
        """
        totalNum = data.shape[1]
        kept = data[:8, self.phase:]
        keptNum = kept.shape[1]
        if keptNum == 0:
            # The whole chunk lies before the next sample to keep.
            self.phase -= totalNum
        else:
            self.phase = n - 1 - (keptNum - 1) % n
        # np.array copies and drops any np.matrix subclass.
        return np.array(kept[:, ::n])

    def preprocessFilter(self, data):
        """Pre-process a chunk: currently only notch filtering (downsampling is disabled)."""
        return self.northFilter(data[:, :])

    def fbccaDWMW(self, testdata, referenceData, tValue, calculateCount=None):
        """Classify a new data chunk with the dynamic multi-window FBCCA.

        Args:
            testdata: 2-D array (num_chans, samples) holding the new samples.
            referenceData: Output of :meth:`reference`; it must cover at least
                ``winMaxSampleNum`` samples.
            tValue: Decision threshold; a label is returned only when the
                smallest cost value over the windows is below it.
            calculateCount: Unused; accepted for caller compatibility.

        Returns:
            The index of the decided target, or -1 if no decision is made.

        Side effects: updates the filter states, ``dataUseLen``, ``Rbuffer`` and
        ``Cxy``. It doubles ``winStep`` (up to 16) when a call takes more than
        0.2 s.
        """
        t1 = time.perf_counter()
        res = -1
        predLabel = -1
        minEps = float("inf")
        winMinTime = 0  # minimum window length (s) before a window may vote

        northData = self.preprocessFilter(testdata)
        newSampleNum = northData.shape[1]

        # A chunk longer than the window delay breaks the window start rule below.
        if newSampleNum > self.winDelayNum:
            warnings.warn('need add window delay time')

        # With fewer samples than channels the first QR factor is rank deficient.
        if newSampleNum < self.num_chans:
            warnings.warn('data shape is [%d %d] need more data' % (newSampleNum, northData.shape[0]))

        filterData = self.filterbank(northData)

        for wi in range(0, self.winNum, self.winStep):
            if wi == 0 or self.dataUseLen[wi] != 0:
                self.dataUseLen[wi] += newSampleNum
            elif self.dataUseLen[wi - self.winStep] > self.winDelayNum * self.winStep:
                # An idle window starts once the previous window is past the delay.
                self.dataUseLen[wi] = newSampleNum
            else:
                # This window cannot start yet, so stop at it.
                break

            if self.dataUseLen[wi] > self.winMaxSampleNum:
                # Window too long: restart it with the current chunk only.
                self.dataUseLen[wi] = newSampleNum
                self.Rbuffer[wi] = 0
                self.Cxy[wi] = 0

            Qs1, Qs2, self.Rbuffer[wi] = self.getDataQ(filterData, self.Rbuffer[wi])

            # References aligned with the new samples' position inside this window.
            ei = self.dataUseLen[wi]
            si = ei - newSampleNum
            ref = referenceData[:, :, si:ei]

            predLabel_new, rho_new, self.Cxy[wi] = self.weightCCA(Qs1, Qs2, ref, self.Cxy[wi])

            if self.dataUseLen[wi] > winMinTime * self.fs:
                epsilon = self.costF(rho_new, self.DW_cost_method, C=0)
                if epsilon < minEps:
                    minEps = epsilon
                    predLabel = predLabel_new

        if minEps < tValue:
            res = predLabel

        # Adaptive load shedding: visit fewer windows when a call is slow.
        if time.perf_counter() - t1 > 0.2 and self.winStep < 16:
            self.winStep = self.winStep * 2

        return res
|
||
|
||
|
||
if __name__ == '__main__':
    # Offline replay: warm the filters on pre-stimulus data, then feed the
    # recorded chunks one at a time until the classifier makes a decision.
    fs = 250  # sampling rate (Hz); the built-in notch coefficients assume 250 Hz
    num_chans = 8
    num_target = 40
    num_filterBank = 3  # number of sub-bands in the filter-bank analysis
    num_harm = 5
    stimTime = 0.2  # delay between window starts (s)
    winNum = 50  # number of analysis windows
    trials = 1
    step = 50  # maximum number of chunks fed per trial
    res = -1

    list_freqs = [8.0, 8.2, 8.4, 8.6, 8.8, 9.0, 9.2, 9.4, 9.6, 9.8, 10.0, 10.2, 10.4, 10.6, 10.8, 11.0, 11.2, 11.4,
                  11.6, 11.8, 12.0, 12.2, 12.4, 12.6, 12.8, 13., 13.2, 13.4, 13.6, 13.8, 14., 14.2, 14.4, 14.6, 14.8,
                  15.0, 15.2, 15.4, 15.6, 15.8]

    # NOTE(review): `method` is a required constructor argument. method=1
    # selects the 'DW11' cost -(rho1 - rho2); it is assumed here because the
    # tValue=-0.2 threshold used below looks like it is on that scale. Confirm.
    dw = FbccaDw(fs, num_target, num_chans, num_filterBank, num_harm, stimTime,
                 [1.0, 0.3], [8, 8], winNum, method=1)

    dw.filterFrequenceBank()
    # The references must cover the longest window (winNum * stimTime seconds).
    referenceData = dw.reference(list_freqs, int(winNum * stimTime * fs), num_harm)
    dw.setNotchFilterPara()

    prelabels = np.zeros((1, 40))

    path = "D:\\工作相关\\项目代码\\SDK封装与测试\\八神BCI脑电信号SDK20230105\\双鹰SDK对比\\offline_data250\\"

    for index in range(1, trials + 1):
        # Files are named '<trial>-warmData.mat' and '<trial>-<chunk>.mat'.
        warmData = loadmat(os.path.join(path, '%d-warmData.mat' % index))['warmData']

        dw.onlineInit()
        dw.filterInit()
        dw.warmFilter(warmData.T)

        for chunk_i in range(1, step + 1):
            dataSlice = loadmat(os.path.join(path, '%d-%d.mat' % (index, chunk_i)))['dataTemp']
            res = dw.fbccaDWMW(dataSlice.T, referenceData, tValue=-0.2, calculateCount=chunk_i)
            if res != -1:
                break

        prelabels[0, index - 1] = res + 1
        print(index, '--', res + 1, " 计算轮数", chunk_i)
|