add buffer

This commit is contained in:
2026-06-06 14:40:07 +08:00
parent 868ff30238
commit 2d190d6431
5 changed files with 414 additions and 266 deletions

View File

@@ -1,4 +1,6 @@
import ast
import os
import sys
import threading
from datetime import datetime
import multiprocessing as mp
@@ -8,7 +10,7 @@ import torch
from queue import Empty
from scipy import signal
from torch.autograd import Variable
from Device.SunnyLinker import SunnyLinker64
# from Device.SunnyLinker import SunnyLinker64
from SSMVEP.algorithm.tdca import TDCA
from SSMVEP.algorithm.base import generate_cca_references
from concentration.algorithm.calculate_focus import Calculate
@@ -17,49 +19,70 @@ from Zmq.zmqServer import zmqServer
from Zmq.zmqClient import zmqClient
from MI.Algorithm.conformer_2class import onlineTrain
from PubLibrary.InifileHelper import IniRead
from logs.log import algo_log
from SSVEP.dwfbcca import FbccaDw
from Tools.plot_MI_EEG import plotMain
# from Tools.plot_MI_EEG import plotMain
from collections import deque
class Decoder_main(threading.Thread, device_type):
def __init__(self, device_type=None):
def get_root_path():
"""
Nuitka 打包专用:获取程序根目录(.py 或 .exe 所在目录)
"""
if getattr(sys, 'frozen', False):
# 打包后:返回 exe 所在目录
return os.path.dirname(sys.executable)
else:
# 开发时:返回 py 文件所在目录
return os.path.dirname(os.path.abspath(__file__))
MODEL_FOLDER = "online_Models"
class Decoder_main(threading.Thread):
def __init__(self, device_info=None):
threading.Thread.__init__(self)
self.device_info = {
'sample_rate': device_info['sample_rate'],
'frame_points': device_info['frame_points'],
'channel_nums': device_info['channel_nums'],
'channel_names': device_info['channel_names'],
'channel_index': device_info['channel_index'],
}
self.Runing=True
self.decoder = None
self.fs = 250 # 采样率
self.energy = 0 # 电量
self.status_code = 0 # 与采集设备通信的状态码0为异常1为正常
self.decoder_class = None #解码器类别
# 与采集设备通信的状态码0为异常1为正常
# self.status_code = 0
# self.device_info['sample_rate'] = 250 # 采样率
# self.energy = 0 # 电量
self.decodingSteps = 0 # 0=停止解码 1=预热 2=解码中 3=解码完成,发送解码结果
self.device_info = {
'device_type': None,
'sample_rate': None,
'channel_num': None,
}
def connect(self, device_type=None, device_host=None, device_port=None, upper_host=None, upper_port=None):
self.DeviceType = device_type if device_type is not None else int(IniRead('system', 'Device_type'))
_device_host = device_host if device_host is not None else str(IniRead('system', 'Device_Host'))
_device_port = device_port if device_port is not None else int(IniRead('system', 'Device_Port'))
_upper_host = upper_host if upper_host is not None else str(IniRead('system', 'Upper_Host'))
_upper_port = upper_port if upper_port is not None else int(IniRead('system', 'Upper_Port'))
if self.DeviceType == 1:
self.thread_data_server = SunnyLinker64(_device_host, _device_port, self.fs, 64, method='tcp')
self.thread_data_server.host = _device_host
self.thread_data_server.port = _device_port
self.thread_data_server.toUv = True
self.thread_data_server.start()
self.zmqServer = zmqServer()
self.zmqServer = zmqServer(device_info=self.device_info)
self.zmqServer.start()
self.zmqClient = zmqClient(_upper_host, _upper_port)
self.zmqClient.set_zmq_server(self.zmqServer)
self.zmqClient.connect()
# self.zmqClient = zmqClient(_upper_host, _upper_port)
# self.zmqClient.set_zmq_server(self.zmqServer)
# self.zmqClient.connect()
# def connect(self, device_type=None, device_host=None, device_port=None, upper_host=None, upper_port=None):
# self.DeviceType = device_type if device_type is not None else int(IniRead('system', 'Device_type'))
# _device_host = device_host if device_host is not None else str(IniRead('system', 'Device_Host'))
# _device_port = device_port if device_port is not None else int(IniRead('system', 'Device_Port'))
# _upper_host = upper_host if upper_host is not None else str(IniRead('system', 'Upper_Host'))
# _upper_port = upper_port if upper_port is not None else int(IniRead('system', 'Upper_Port'))
# if self.DeviceType == 1:
# self.thread_data_server = SunnyLinker64(_device_host, _device_port, self.device_info['sample_rate'], 64, method='tcp')
# self.thread_data_server.host = _device_host
# self.thread_data_server.port = _device_port
# self.thread_data_server.toUv = True
# self.thread_data_server.start()
def is_valid_signal(self, data, threshold=1e5): # 判断当前信号是否为有效信号
# data: (chans, samples)
@@ -76,26 +99,25 @@ class Decoder_main(threading.Thread, device_type):
self.decoder_class = decoder_class
if decoder_class == 'ssvep' or decoder_class == 'pvs':
self.n_chan = 8
self.thread_data_server.interval_inited = False
# self.thread_data_server.interval_inited = False
DW_cost_method, self.DW_cost_tv = ast.literal_eval(IniRead('system', 'SSVEP_ThresholdValue'))
self.ListFreq = self.zmqServer.targetFreqs
self.num_target = len(self.ListFreq)
if self.num_target == 0:
return
# 初始化对象 二代算法
self.dw = FbccaDw(self.fs, self.num_target, self.n_chan, 5, 5,
self.dw = FbccaDw(self.device_info['sample_rate'], self.num_target, self.n_chan, 5, 5,
0.2, [2.0, 0.1], [8, 7], 50, DW_cost_method)
# frequence band
self.dw.filterFrequenceBank()
self.dw.setNotchFilterPara()
self.calculateCount = 0
self.referenceData = self.dw.reference(self.ListFreq, int(50 * 0.2 * self.fs),
5)
self.referenceData = self.dw.reference(self.ListFreq, int(50 * 0.2 * self.device_info['sample_rate']), 5)
self.dw.filterInit()
self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时
elif decoder_class == 'ssmvep':
self.thread_data_server.interval_init(decoder_class)
self.zmqServer.interval_init(decoder_class)
self.n_chan = 8
self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位
@@ -104,12 +126,12 @@ class Decoder_main(threading.Thread, device_type):
self.list_freqs = np.array([8, 9]) # 刺激频率
self.list_phase = np.array([0, 0]) # 相位
self.tdca = TDCA(padding_len=5, n_components=1)
self.Yf = generate_cca_references(self.list_freqs, srate=self.fs, T=self.sample_length,
self.Yf = generate_cca_references(self.list_freqs, srate=self.device_info['sample_rate'], T=self.sample_length,
phases=self.list_phase, n_harmonics=5)
self.parameter_init(5,45)
elif decoder_class == 'mi' or decoder_class == 'ma':
self.thread_data_server.interval_init(decoder_class)
self.zmqServer.interval_init(decoder_class)
self.n_chan = 21
self.interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位
@@ -124,7 +146,7 @@ class Decoder_main(threading.Thread, device_type):
# self.win_len = 10
# self.win_step = 1
# self.low_threshold, self.high_threshold = ast.literal_eval(IniRead('system', 'concentration_ThresholdValue'))
# self.calculate = Calculate(self.low_threshold, self.high_threshold, self.fs, self.win_len)
# self.calculate = Calculate(self.low_threshold, self.high_threshold, self.device_info['sample_rate'], self.win_len)
# self.interval_epoch = [0, 1]
# self.parameter_init(2, 40)
# # self.eegQueue moved to Calculate class
@@ -136,8 +158,8 @@ class Decoder_main(threading.Thread, device_type):
# self.total_samples = 0 # 总采样点数
# self.window_ms = 600 # 检测窗口大小 (ms)
# self.step_ms = 100 # 滑动步长 (ms)
# self.window_samples = int(self.window_ms * self.fs / 1000) # 150个样本点
# self.step_samples = int(self.step_ms * self.fs / 1000) # 25个样本点
# self.window_samples = int(self.window_ms * self.device_info['sample_rate'] / 1000) # 150个样本点
# self.step_samples = int(self.step_ms * self.device_info['sample_rate'] / 1000) # 25个样本点
# self.buffer_size = self.window_samples + self.step_samples * 5
# self.fp1_buffer = deque(maxlen=self.buffer_size)
# self.fp2_buffer = deque(maxlen=self.buffer_size)
@@ -151,11 +173,11 @@ class Decoder_main(threading.Thread, device_type):
# self.double_blink_events = [] # 连续眨眼事件记录
# self.last_double_blink_time = 0 # 上次检测到连续眨眼的时间戳
# self.blink_events = []
# self.blink_b, self.blink_a = signal.butter(4, [self.l_freq / (self.fs / 2), self.h_freq / (self.fs / 2)], btype='band')
# self.blink_b, self.blink_a = signal.butter(4, [self.l_freq / (self.device_info['sample_rate'] / 2), self.h_freq / (self.device_info['sample_rate'] / 2)], btype='band')
def parameter_init(self,bandPass_low,bandPass_high):
self.interval_epoch = [int(i * self.fs) for i in self.interval_epoch] # epoch截取信息
self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * self.fs)] # 训练样本epoch
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in self.interval_epoch] # epoch截取信息
self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])] # 训练样本epoch
self.trainData = [] #训练数据
self.trainLabel = [] #训练标签
self.plotData = [] #报告分析数据
@@ -163,10 +185,10 @@ class Decoder_main(threading.Thread, device_type):
self.currentLabel = -1 #刺激界面当前显示的训练标签
self.train_started = False #是否开始训练模型
self.load_model = False # 调用模型是否完成的标志
self.b_notch, self.a_notch = signal.iirnotch(50 / (self.fs/2), 30) # 50Hz工频陷波250是采样率30是质量因子
self.b_design = signal.firwin(65, [bandPass_low / (self.fs/2), bandPass_high / (self.fs/2)], pass_zero=False) # 设计8-30Hz带通滤波器
self.b_notch, self.a_notch = signal.iirnotch(50 / (self.device_info['sample_rate']/2), 30) # 50Hz工频陷波250是采样率30是质量因子
self.b_design = signal.firwin(65, [bandPass_low / (self.device_info['sample_rate']/2), bandPass_high / (self.device_info['sample_rate']/2)], pass_zero=False) # 设计8-30Hz带通滤波器
fileName = 'Model_' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
filePath = './online_Models/'
filePath = os.path.join(get_root_path(), MODEL_FOLDER) + os.sep
self.modelPath = ''.join([filePath, fileName, '.pth'])
self.mp_data_queue = mp.Queue() #多进程传参队列
self.mp_result_queue = mp.Queue() #多进程结果队列
@@ -192,54 +214,55 @@ class Decoder_main(threading.Thread, device_type):
# 同步信息
if self.zmqServer.state_mode == 'sync':
self.zmqClient.send_to_all('sync', self.zmqClient.state)
# self.zmqClient.send_to_all('sync', self.zmqClient.state)
self.zmqServer.state_mode = 'rest'
# 状态异常,报告上位机
if self.status_code != self.thread_data_server.status_code:
self.status_code = self.thread_data_server.status_code
self.zmqClient.send_to_all('status_code', int(self.status_code))
print('status code')
# 返回电量
if self.energy != self.thread_data_server.energy:
self.energy = self.thread_data_server.energy
self.zmqClient.send_to_all('energy', int(self.energy))
print('energy')
# # 状态异常,报告上位机
# if self.status_code != self.thread_data_server.status_code:
# self.status_code = self.thread_data_server.status_code
# self.zmqClient.send_to_all('status_code', int(self.status_code))
# print('status code')
if self.zmqServer.open_Impedance == True: # 开启阻抗检测功能,仅运行一次
self.thread_data_server.Impedance(True)
print('Impedance')
self.zmqServer.open_Impedance = -1
elif self.zmqServer.open_Impedance == False:
self.thread_data_server.Impedance(False)
self.zmqServer.open_Impedance = -1
# # 返回电量
# if self.energy != self.thread_data_server.energy:
# self.energy = self.thread_data_server.energy
# self.zmqClient.send_to_all('energy', int(self.energy))
# print('energy')
if self.zmqServer.get_Impedance: # 返回阻抗值
# print(self.zmqServer.get_Impedance)
# print(self.thread_data_server.GetDataLenCount())
if self.thread_data_server.GetDataLenCount() > 250:
Impe_data = self.thread_data_server.getData(250)
# 计算阻抗
imps = self.thread_data_server.getImpedance(Impe_data,self.zmqServer.decoder_class)
self.zmqClient.send_to_all('impedance', imps.tolist())
else:
pass
if self.zmqServer.getReport: #返回训练报告内容
self.zmqServer.getReport = False
allData = np.array(self.plotData)
allLabel = np.array(self.plotLabel) + 1
nTrials = min(len(allLabel),len(allData))
if nTrials < 30:
self.zmqClient.send_to_all('miReport',0)
else:
allData = allData[:nTrials]
allLabel = allLabel[:nTrials]
ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1',
'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4']
compare_names = ['C3', 'CZ', 'C4']
miReport = plotMain(ch_names=ch_names,compare_names=compare_names,Data=allData,labels=allLabel,MI_label=1,Rest_label=2,
fs=self.fs)
self.zmqClient.send_to_all('miReport',miReport)
# if self.zmqServer.open_Impedance == True: # 开启阻抗检测功能,仅运行一次
# self.thread_data_server.Impedance(True)
# print('Impedance')
# self.zmqServer.open_Impedance = -1
# elif self.zmqServer.open_Impedance == False:
# self.thread_data_server.Impedance(False)
# self.zmqServer.open_Impedance = -1
# if self.zmqServer.get_Impedance: # 返回阻抗值
# # print(self.zmqServer.get_Impedance)
# # print(self.thread_data_server.GetDataLenCount())
# if self.thread_data_server.GetDataLenCount() > 250:
# Impe_data = self.thread_data_server.getData(250)
# # 计算阻抗
# imps = self.thread_data_server.getImpedance(Impe_data,self.zmqServer.decoder_class)
# self.zmqClient.send_to_all('impedance', imps.tolist())
# else:
# pass
# if self.zmqServer.getReport: #返回训练报告内容
# self.zmqServer.getReport = False
# allData = np.array(self.plotData)
# allLabel = np.array(self.plotLabel) + 1
# nTrials = min(len(allLabel),len(allData))
# if nTrials < 30:
# self.zmqClient.send_to_all('miReport',0)
# else:
# allData = allData[:nTrials]
# allLabel = allLabel[:nTrials]
# ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1',
# 'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4']
# compare_names = ['C3', 'CZ', 'C4']
# miReport = plotMain(ch_names=ch_names,compare_names=compare_names,Data=allData,labels=allLabel,MI_label=1,Rest_label=2,
# fs=self.device_info['sample_rate'])
# self.zmqClient.send_to_all('miReport',miReport)
# --- 取数优先:先执行 decoder消费环形缓冲再处理 plot/report 等重负载 ---
@@ -250,34 +273,33 @@ class Decoder_main(threading.Thread, device_type):
self.decoder_SSMVEP()
elif self.decoder_class == 'mi':
self.decoder_MI()
elif self.decoder_class == 'concentration':
self.decoder_concentration()
elif self.decoder_class == 'blink':
self.decoder_blink()
else:
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
if self.thread_data_server.GetDataLenCount() < 25:
time.sleep(0.005)
continue;
self.thread_data_server.getData(25)
# elif self.decoder_class == 'concentration':
# self.decoder_concentration()
# elif self.decoder_class == 'blink':
# self.decoder_blink()
# else:
# self.
# # if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
# # if self.thread_data_server.GetDataLenCount() < 25:
# # time.sleep(0.005)
# # continue;
# # self.thread_data_server.getData(25)
except Exception as e:
print(f"Decoder Loop Error: {e}")
import traceback
traceback.print_exc()
algo_log(f"Decoder Loop Error: {e}")
time.sleep(0.1) # Prevent CPU spin if error is persistent
def decoder_SSVEP(self):
if self.zmqServer.StartDecode:
self.zmqServer.StartDecode = False
self.decodingSteps = 1
self.thread_data_server.ResetAll()
self.zmqServer.paradigmBuffer.ResetAllPara()
print('启动预测')
if self.thread_data_server.GetDataLenCount() < 50:
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 50:
time.sleep(0.005)
return
if self.zmqServer.get_Impedance != False: # 阻抗检测状态不解码
return
data = self.thread_data_server.getDataViaSSVEP(50)
data = self.zmqServer.paradigmBuffer.getDataViaSSVEP(50)
data = data[:self.n_chan, :]
if self.decodingSteps == 1 and hasattr(self,'dw'): # 开始预热
self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时
@@ -293,7 +315,7 @@ class Decoder_main(threading.Thread, device_type):
print('预测结果:' + str(choosenNum) + ',计算次数:' + str(self.calculateCount))
self.calculateCount = 0
if self.decodingSteps == 3: # 发送解码后的信息
self.zmqClient.send_to_all('result', int(choosenNum))
self.zmqServer.broadcast_message('result', int(choosenNum))
self.decodingSteps = 0
print('发送给界面完成。')
@@ -312,19 +334,19 @@ class Decoder_main(threading.Thread, device_type):
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
print('模型训练完成', formatted_time)
self.load_model = True
self.zmqClient.send_to_all('paradigm', 1)
self.zmqServer.broadcast_message('paradigm', 1)
'''训练阶段采集数据'''
if self.zmqServer.state_mode == 'train': # 训练状态
if self.zmqServer.StartTrain:
self.currentLabel = self.zmqServer.currentLabel
self.zmqServer.StartTrain = False
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \
if self.thread_data_server.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
self.train_epoch[1] \
+ self.thread_data_server.event_inner_idx:
time.sleep(0.0001)
return
print('训练队列数据:', self.thread_data_server.GetDataLenCount())
print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount())
trainTrial = self.thread_data_server.get_SSMVEPData() # 取出所有数据
print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, self.thread_data_server.event_inner_idx])
trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理
@@ -348,7 +370,7 @@ class Decoder_main(threading.Thread, device_type):
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
print('启动预测 ', formatted_time)
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \
if self.thread_data_server.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
self.interval_epoch[1] \
+ self.thread_data_server.event_inner_idx:
time.sleep(0.0001)
@@ -360,8 +382,8 @@ class Decoder_main(threading.Thread, device_type):
self.thread_data_server.event_inner_idx + self.interval_epoch[
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]]
pad_eeg_test = np.zeros(
(data.shape[0], int((self.sample_length + 0.1) * self.fs)))
pad_eeg_test[:, :int(self.sample_length * self.fs)] = data
(data.shape[0], int((self.sample_length + 0.1) * self.device_info['sample_rate'])))
pad_eeg_test[:, :int(self.sample_length * self.device_info['sample_rate'])] = data
choosenNum, features_2 = self.decoder.predict(pad_eeg_test)
if isinstance(choosenNum, np.ndarray):
choosenNum = choosenNum[0]
@@ -371,7 +393,7 @@ class Decoder_main(threading.Thread, device_type):
print('发送给界面完成。')
else: # 休息状态
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
if self.thread_data_server.GetDataLenCount() < 25:
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
time.sleep(0.005)
return
self.thread_data_server.getData(25)
@@ -419,12 +441,12 @@ class Decoder_main(threading.Thread, device_type):
if self.zmqServer.StartTrain:
self.currentLabel = self.zmqServer.currentLabel
self.zmqServer.StartTrain = False
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \
if self.thread_data_server.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
self.interval_epoch[1] \
+ self.thread_data_server.event_inner_idx:
time.sleep(0.0001)
return
print('训练队列数据:', self.thread_data_server.GetDataLenCount())
print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount())
originalTrial = self.thread_data_server.get_MIData() # 取出MI导联数据
print('取出的: ', originalTrial.shape, 'event: ', originalTrial[-2, self.thread_data_server.event_inner_idx])
trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理
@@ -448,7 +470,7 @@ class Decoder_main(threading.Thread, device_type):
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
print('启动预测 ', formatted_time)
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \
if self.thread_data_server.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
self.interval_epoch[1] \
+ self.thread_data_server.event_inner_idx:
time.sleep(0.0001)
@@ -477,128 +499,128 @@ class Decoder_main(threading.Thread, device_type):
print(f'发送给界面完成,耗时{end - start:.3f}s。')
else: # 休息状态
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
if self.thread_data_server.GetDataLenCount() < 25:
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
time.sleep(0.005)
return
self.thread_data_server.getData(25)
def decoder_concentration(self):
if self.zmqServer.state_mode == 'predict':
if self.zmqServer.StartDecode:
self.zmqServer.StartDecode = False
self.thread_data_server.ResetAll()
now = datetime.now()
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
print('启动专注力预测 ', formatted_time)
if self.thread_data_server.GetDataLenCount() < int(self.win_step * self.fs): # 每win_step得出一次结果
time.sleep(0.005)
return
if self.zmqServer.get_Impedance != False: # 阻抗检测状态不解码
return
data = self.thread_data_server.get_concentrateData(int(self.win_step * self.fs)) # 修改每次读取的数据
result = self.calculate.queueOpt(data)
if result is not None:
self.zmqClient.send_to_all('result', int(result))
else: # 休息状态
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
if self.thread_data_server.GetDataLenCount() < 25:
time.sleep(0.005)
return
self.thread_data_server.getData(25)
# def decoder_concentration(self):
# if self.zmqServer.state_mode == 'predict':
# if self.zmqServer.StartDecode:
# self.zmqServer.StartDecode = False
# self.thread_data_server.ResetAll()
# now = datetime.now()
# formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
# print('启动专注力预测 ', formatted_time)
# if self.thread_data_server.GetDataLenCount() < int(self.win_step * self.device_info['sample_rate']): # 每win_step得出一次结果
# time.sleep(0.005)
# return
# if self.zmqServer.get_Impedance != False: # 阻抗检测状态不解码
# return
# data = self.thread_data_server.get_concentrateData(int(self.win_step * self.device_info['sample_rate'])) # 修改每次读取的数据
# result = self.calculate.queueOpt(data)
# if result is not None:
# self.zmqClient.send_to_all('result', int(result))
# else: # 休息状态
# if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
# if self.thread_data_server.GetDataLenCount() < 25:
# time.sleep(0.005)
# return
# self.thread_data_server.getData(25)
#### Blink detection #####
def check_double_blink(self, current_time):
"""
检查是否检测到连续两次眨眼
@param current_time: 当前眨眼时间戳
@return: True表示检测到连续两次眨眼
"""
if len(self.blink_timestamps) < 2:
return False
# def check_double_blink(self, current_time):
# """
# 检查是否检测到连续两次眨眼
# @param current_time: 当前眨眼时间戳
# @return: True表示检测到连续两次眨眼
# """
# if len(self.blink_timestamps) < 2:
# return False
# 检查是否在去抖期内
if self.last_double_blink_time > 0:
time_since_last_double_blink = current_time - self.last_double_blink_time
if time_since_last_double_blink < self.double_blink_jitter:
return False # 在去抖期内,忽略连续眨眼检测
last_time = self.blink_timestamps[-1] # 当前眨眼
prev_time = self.blink_timestamps[-2] # 上次眨眼
# # 检查是否在去抖期内
# if self.last_double_blink_time > 0:
# time_since_last_double_blink = current_time - self.last_double_blink_time
# if time_since_last_double_blink < self.double_blink_jitter:
# return False # 在去抖期内,忽略连续眨眼检测
# last_time = self.blink_timestamps[-1] # 当前眨眼
# prev_time = self.blink_timestamps[-2] # 上次眨眼
interval = last_time - prev_time
if interval <= self.double_blink_interval:
return True
# interval = last_time - prev_time
# if interval <= self.double_blink_interval:
# return True
return False
# return False
def process_blink_detection(self):
"""
在缓冲区数据上执行,单次眨眼检测
"""
if len(self.fp1_buffer) < self.window_samples:
return
# def process_blink_detection(self):
# """
# 在缓冲区数据上执行,单次眨眼检测
# """
# if len(self.fp1_buffer) < self.window_samples:
# return
fp1_data = np.array(list(self.fp1_buffer)[-self.window_samples:])
fp2_data = np.array(list(self.fp2_buffer)[-self.window_samples:])
# 计算FP1和FP2的平均
fp12_mean = (fp1_data + fp2_data) / 2.0
# 带通滤波
try:
fp12_filtered = signal.filtfilt(self.blink_b, self.blink_a, fp12_mean)
except Exception as e:
print(f"Filter error: {e}")
return
F = np.diff(fp12_filtered)
if len(F) < 3:
return
b, d, e = blink_detection(F, self.fs, self.Dmin, self.Dmax, self.EMin, self.EMax)
# fp1_data = np.array(list(self.fp1_buffer)[-self.window_samples:])
# fp2_data = np.array(list(self.fp2_buffer)[-self.window_samples:])
# # 计算FP1和FP2的平均
# fp12_mean = (fp1_data + fp2_data) / 2.0
# # 带通滤波
# try:
# fp12_filtered = signal.filtfilt(self.blink_b, self.blink_a, fp12_mean)
# except Exception as e:
# print(f"Filter error: {e}")
# return
# F = np.diff(fp12_filtered)
# if len(F) < 3:
# return
# b, d, e = blink_detection(F, self.device_info['sample_rate'], self.Dmin, self.Dmax, self.EMin, self.EMax)
if b == 1:
samples_since_last = self.total_samples - self.last_blink_time
time_since_last_ms = (samples_since_last / self.fs) * 1000
if time_since_last_ms >= self.jitterwin: # self.jitterwin 单次眨眼去抖 using time_since_last_ms
self.blink_count += 1
self.last_blink_time = self.total_samples
current_time = time.time()
self.blink_timestamps.append(current_time)
blink_event = {
'count': self.blink_count,
'time': current_time,
'sample_index': self.total_samples,
'duration_ms': d,
'energy': e
}
self.blink_events.append(blink_event)
self.zmqClient.send_to_all('result', 1) # 检测到眨眼信号,通知上位机
if self.check_double_blink(current_time):
self.double_blink_count += 1
interval = self.blink_timestamps[-1] - self.blink_timestamps[-2]
double_blink_event = {
'double_blink_count': self.double_blink_count,
'blink1_time': self.blink_timestamps[-2],
'blink2_time': self.blink_timestamps[-1],
'interval': interval
}
self.double_blink_events.append(double_blink_event)
self.last_double_blink_time = current_time
self.zmqClient.send_to_all('result', 2) # 发送双次眨眼事件
# if b == 1:
# samples_since_last = self.total_samples - self.last_blink_time
# time_since_last_ms = (samples_since_last / self.device_info['sample_rate']) * 1000
# if time_since_last_ms >= self.jitterwin: # self.jitterwin 单次眨眼去抖 using time_since_last_ms
# self.blink_count += 1
# self.last_blink_time = self.total_samples
# current_time = time.time()
# self.blink_timestamps.append(current_time)
# blink_event = {
# 'count': self.blink_count,
# 'time': current_time,
# 'sample_index': self.total_samples,
# 'duration_ms': d,
# 'energy': e
# }
# self.blink_events.append(blink_event)
# self.zmqClient.send_to_all('result', 1) # 检测到眨眼信号,通知上位机
# if self.check_double_blink(current_time):
# self.double_blink_count += 1
# interval = self.blink_timestamps[-1] - self.blink_timestamps[-2]
# double_blink_event = {
# 'double_blink_count': self.double_blink_count,
# 'blink1_time': self.blink_timestamps[-2],
# 'blink2_time': self.blink_timestamps[-1],
# 'interval': interval
# }
# self.double_blink_events.append(double_blink_event)
# self.last_double_blink_time = current_time
# self.zmqClient.send_to_all('result', 2) # 发送双次眨眼事件
def decoder_blink(self):
if self.thread_data_server.GetDataLenCount() < 50:
time.sleep(0.005)
return
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
data = self.thread_data_server.get_blinkData(50)
fp1_data = data[0, :] # ch1 (相当于FP1)
fp2_data = data[1, :] # ch2 (相当于FP2)
for i in range(len(fp1_data)):
self.fp1_buffer.append(fp1_data[i])
self.fp2_buffer.append(fp2_data[i])
self.total_samples += 1
self.sample_counter += 1
# def decoder_blink(self):
# if self.thread_data_server.GetDataLenCount() < 50:
# time.sleep(0.005)
# return
# if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
# data = self.thread_data_server.get_blinkData(50)
# fp1_data = data[0, :] # ch1 (相当于FP1)
# fp2_data = data[1, :] # ch2 (相当于FP2)
# for i in range(len(fp1_data)):
# self.fp1_buffer.append(fp1_data[i])
# self.fp2_buffer.append(fp2_data[i])
# self.total_samples += 1
# self.sample_counter += 1
if self.sample_counter >= self.step_samples:
self.process_blink_detection()
self.sample_counter = 0
# if self.sample_counter >= self.step_samples:
# self.process_blink_detection()
# self.sample_counter = 0
def stop(self):
'''