diff --git a/Decoder.py b/Decoder.py index c07dce6..3175b37 100644 --- a/Decoder.py +++ b/Decoder.py @@ -1,4 +1,6 @@ import ast +import os +import sys import threading from datetime import datetime import multiprocessing as mp @@ -8,7 +10,7 @@ import torch from queue import Empty from scipy import signal from torch.autograd import Variable -from Device.SunnyLinker import SunnyLinker64 +# from Device.SunnyLinker import SunnyLinker64 from SSMVEP.algorithm.tdca import TDCA from SSMVEP.algorithm.base import generate_cca_references from concentration.algorithm.calculate_focus import Calculate @@ -17,49 +19,70 @@ from Zmq.zmqServer import zmqServer from Zmq.zmqClient import zmqClient from MI.Algorithm.conformer_2class import onlineTrain from PubLibrary.InifileHelper import IniRead +from logs.log import algo_log from SSVEP.dwfbcca import FbccaDw -from Tools.plot_MI_EEG import plotMain +# from Tools.plot_MI_EEG import plotMain from collections import deque -class Decoder_main(threading.Thread, device_type): - def __init__(self, device_type=None): + +def get_root_path(): + """ + Nuitka 打包专用:获取程序根目录(.py 或 .exe 所在目录) + """ + if getattr(sys, 'frozen', False): + # 打包后:返回 exe 所在目录 + return os.path.dirname(sys.executable) + else: + # 开发时:返回 py 文件所在目录 + return os.path.dirname(os.path.abspath(__file__)) +MODEL_FOLDER = "online_Models" + + +class Decoder_main(threading.Thread): + def __init__(self, device_info=None): threading.Thread.__init__(self) + self.device_info = { + 'sample_rate': device_info['sample_rate'], + 'frame_points': device_info['frame_points'], + 'channel_nums': device_info['channel_nums'], + 'channel_names': device_info['channel_names'], + 'channel_index': device_info['channel_index'], + } self.Runing=True self.decoder = None - - self.fs = 250 # 采样率 - self.energy = 0 # 电量 - self.status_code = 0 # 与采集设备通信的状态码,0为异常,1为正常 self.decoder_class = None #解码器类别 + # 与采集设备通信的状态码,0为异常,1为正常 + # self.status_code = 0 + # self.device_info['sample_rate'] = 250 # 采样率 + # self.energy = 0 # 电量 + + self.decodingSteps = 0 # 0=停止解码 1=预热 2=解码中 3=解码完成,发送解码结果 - self.device_info = { - 'device_type': None, - 'sample_rate': None, - 'channel_num': None, - } - def connect(self, device_type=None, device_host=None, device_port=None, upper_host=None, upper_port=None): - self.DeviceType = device_type if device_type is not None else int(IniRead('system', 'Device_type')) - _device_host = device_host if device_host is not None else str(IniRead('system', 'Device_Host')) - _device_port = device_port if device_port is not None else int(IniRead('system', 'Device_Port')) - _upper_host = upper_host if upper_host is not None else str(IniRead('system', 'Upper_Host')) - _upper_port = upper_port if upper_port is not None else int(IniRead('system', 'Upper_Port')) - - if self.DeviceType == 1: - self.thread_data_server = SunnyLinker64(_device_host, _device_port, self.fs, 64, method='tcp') - self.thread_data_server.host = _device_host - self.thread_data_server.port = _device_port - - self.thread_data_server.toUv = True - self.thread_data_server.start() - - self.zmqServer = zmqServer() + self.zmqServer = zmqServer(device_info=self.device_info) self.zmqServer.start() - self.zmqClient = zmqClient(_upper_host, _upper_port) - self.zmqClient.set_zmq_server(self.zmqServer) - self.zmqClient.connect() + # self.zmqClient = zmqClient(_upper_host, _upper_port) + # self.zmqClient.set_zmq_server(self.zmqServer) + # self.zmqClient.connect() + + + # def connect(self, device_type=None, device_host=None, device_port=None, upper_host=None, upper_port=None): + # self.DeviceType = device_type if device_type is not None else int(IniRead('system', 'Device_type')) + # _device_host = device_host if device_host is not None else str(IniRead('system', 'Device_Host')) + # _device_port = device_port if device_port is not None else int(IniRead('system', 'Device_Port')) + # _upper_host = upper_host if upper_host is not None else str(IniRead('system', 'Upper_Host')) + # _upper_port = upper_port if upper_port is not None else int(IniRead('system', 'Upper_Port')) + + # if self.DeviceType == 1: + # self.thread_data_server = SunnyLinker64(_device_host, _device_port, self.device_info['sample_rate'], 64, method='tcp') + # self.thread_data_server.host = _device_host + # self.thread_data_server.port = _device_port + + # self.thread_data_server.toUv = True + # self.thread_data_server.start() + def is_valid_signal(self, data, threshold=1e5): # 判断当前信号是否为有效信号 # data: (chans, samples) @@ -76,26 +99,25 @@ class Decoder_main(threading.Thread, device_type): self.decoder_class = decoder_class if decoder_class == 'ssvep' or decoder_class == 'pvs': self.n_chan = 8 - self.thread_data_server.interval_inited = False + # self.thread_data_server.interval_inited = False DW_cost_method, self.DW_cost_tv = ast.literal_eval(IniRead('system', 'SSVEP_ThresholdValue')) self.ListFreq = self.zmqServer.targetFreqs self.num_target = len(self.ListFreq) if self.num_target == 0: return # 初始化对象 二代算法 - self.dw = FbccaDw(self.fs, self.num_target, self.n_chan, 5, 5, + self.dw = FbccaDw(self.device_info['sample_rate'], self.num_target, self.n_chan, 5, 5, 0.2, [2.0, 0.1], [8, 7], 50, DW_cost_method) # frequence band self.dw.filterFrequenceBank() self.dw.setNotchFilterPara() self.calculateCount = 0 - self.referenceData = self.dw.reference(self.ListFreq, int(50 * 0.2 * self.fs), - 5) + self.referenceData = self.dw.reference(self.ListFreq, int(50 * 0.2 * self.device_info['sample_rate']), 5) self.dw.filterInit() self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时 elif decoder_class == 'ssmvep': - self.thread_data_server.interval_init(decoder_class) + self.zmqServer.interval_init(decoder_class) self.n_chan = 8 self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch')) self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位 @@ -104,12 +126,12 @@ class Decoder_main(threading.Thread, device_type): self.list_freqs = np.array([8, 9]) # 刺激频率 self.list_phase = np.array([0, 0]) # 相位 self.tdca = TDCA(padding_len=5, n_components=1) - self.Yf = generate_cca_references(self.list_freqs, srate=self.fs, T=self.sample_length, + self.Yf = generate_cca_references(self.list_freqs, srate=self.device_info['sample_rate'], T=self.sample_length, phases=self.list_phase, n_harmonics=5) self.parameter_init(5,45) elif decoder_class == 'mi' or decoder_class == 'ma': - self.thread_data_server.interval_init(decoder_class) + self.zmqServer.interval_init(decoder_class) self.n_chan = 21 self.interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch')) self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位 @@ -124,7 +146,7 @@ class Decoder_main(threading.Thread, device_type): # self.win_len = 10 # self.win_step = 1 # self.low_threshold, self.high_threshold = ast.literal_eval(IniRead('system', 'concentration_ThresholdValue')) - # self.calculate = Calculate(self.low_threshold, self.high_threshold, self.fs, self.win_len) + # self.calculate = Calculate(self.low_threshold, self.high_threshold, self.device_info['sample_rate'], self.win_len) # self.interval_epoch = [0, 1] # self.parameter_init(2, 40) # # self.eegQueue moved to Calculate class @@ -136,8 +158,8 @@ class Decoder_main(threading.Thread, device_type): # self.total_samples = 0 # 总采样点数 # self.window_ms = 600 # 检测窗口大小 (ms) # self.step_ms = 100 # 滑动步长 (ms) - # self.window_samples = int(self.window_ms * self.fs / 1000) # 150个样本点 - # self.step_samples = int(self.step_ms * self.fs / 1000) # 25个样本点 + # self.window_samples = int(self.window_ms * self.device_info['sample_rate'] / 1000) # 150个样本点 + # self.step_samples = int(self.step_ms * self.device_info['sample_rate'] / 1000) # 25个样本点 # self.buffer_size = self.window_samples + self.step_samples * 5 # self.fp1_buffer = deque(maxlen=self.buffer_size) # self.fp2_buffer = deque(maxlen=self.buffer_size) @@ -151,11 +173,11 @@ class Decoder_main(threading.Thread, device_type): # self.double_blink_events = [] # 连续眨眼事件记录 # self.last_double_blink_time = 0 # 上次检测到连续眨眼的时间戳 # self.blink_events = [] - # self.blink_b, self.blink_a = signal.butter(4, [self.l_freq / (self.fs / 2), self.h_freq / (self.fs / 2)], btype='band') + # self.blink_b, self.blink_a = signal.butter(4, [self.l_freq / (self.device_info['sample_rate'] / 2), self.h_freq / (self.device_info['sample_rate'] / 2)], btype='band') def parameter_init(self,bandPass_low,bandPass_high): - self.interval_epoch = [int(i * self.fs) for i in self.interval_epoch] # epoch截取信息 - self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * self.fs)] # 训练样本epoch + self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in self.interval_epoch] # epoch截取信息 + self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])] # 训练样本epoch self.trainData = [] #训练数据 self.trainLabel = [] #训练标签 self.plotData = [] #报告分析数据 @@ -163,10 +185,10 @@ class Decoder_main(threading.Thread, device_type): self.currentLabel = -1 #刺激界面当前显示的训练标签 self.train_started = False #是否开始训练模型 self.load_model = False # 调用模型是否完成的标志 - self.b_notch, self.a_notch = signal.iirnotch(50 / (self.fs/2), 30) # 50Hz工频陷波,250是采样率,30是质量因子 - self.b_design = signal.firwin(65, [bandPass_low / (self.fs/2), bandPass_high / (self.fs/2)], pass_zero=False) # 设计8-30Hz带通滤波器 + self.b_notch, self.a_notch = signal.iirnotch(50 / (self.device_info['sample_rate']/2), 30) # 50Hz工频陷波,250是采样率,30是质量因子 + self.b_design = signal.firwin(65, [bandPass_low / (self.device_info['sample_rate']/2), bandPass_high / (self.device_info['sample_rate']/2)], pass_zero=False) # 设计8-30Hz带通滤波器 fileName = 'Model_' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S') - filePath = './online_Models/' + filePath = os.path.join(get_root_path(), MODEL_FOLDER) + os.sep self.modelPath = ''.join([filePath, fileName, '.pth']) self.mp_data_queue = mp.Queue() #多进程传参队列 self.mp_result_queue = mp.Queue() #多进程结果队列 @@ -192,54 +214,55 @@ class Decoder_main(threading.Thread, device_type): # 同步信息 if self.zmqServer.state_mode == 'sync': - self.zmqClient.send_to_all('sync', self.zmqClient.state) + # self.zmqClient.send_to_all('sync', self.zmqClient.state) self.zmqServer.state_mode = 'rest' - # 状态异常,报告上位机 - if self.status_code != self.thread_data_server.status_code: - self.status_code = self.thread_data_server.status_code - self.zmqClient.send_to_all('status_code', int(self.status_code)) - print('status code') - # 返回电量 - if self.energy != self.thread_data_server.energy: - self.energy = self.thread_data_server.energy - self.zmqClient.send_to_all('energy', int(self.energy)) - print('energy') + # # 状态异常,报告上位机 + # if self.status_code != self.thread_data_server.status_code: + # self.status_code = self.thread_data_server.status_code + # self.zmqClient.send_to_all('status_code', int(self.status_code)) + # print('status code') - if self.zmqServer.open_Impedance == True: # 开启阻抗检测功能,仅运行一次 - self.thread_data_server.Impedance(True) - print('Impedance') - self.zmqServer.open_Impedance = -1 - elif self.zmqServer.open_Impedance == False: - self.thread_data_server.Impedance(False) - self.zmqServer.open_Impedance = -1 + # # 返回电量 + # if self.energy != self.thread_data_server.energy: + # self.energy = self.thread_data_server.energy + # self.zmqClient.send_to_all('energy', int(self.energy)) + # print('energy') - if self.zmqServer.get_Impedance: # 返回阻抗值 - # print(self.zmqServer.get_Impedance) - # print(self.thread_data_server.GetDataLenCount()) - if self.thread_data_server.GetDataLenCount() > 250: - Impe_data = self.thread_data_server.getData(250) - # 计算阻抗 - imps = self.thread_data_server.getImpedance(Impe_data,self.zmqServer.decoder_class) - self.zmqClient.send_to_all('impedance', imps.tolist()) - else: - pass - if self.zmqServer.getReport: #返回训练报告内容 - self.zmqServer.getReport = False - allData = np.array(self.plotData) - allLabel = np.array(self.plotLabel) + 1 - nTrials = min(len(allLabel),len(allData)) - if nTrials < 30: - self.zmqClient.send_to_all('miReport',0) - else: - allData = allData[:nTrials] - allLabel = allLabel[:nTrials] - ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1', - 'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4'] - compare_names = ['C3', 'CZ', 'C4'] - miReport = plotMain(ch_names=ch_names,compare_names=compare_names,Data=allData,labels=allLabel,MI_label=1,Rest_label=2, - fs=self.fs) - self.zmqClient.send_to_all('miReport',miReport) + # if self.zmqServer.open_Impedance == True: # 开启阻抗检测功能,仅运行一次 + # self.thread_data_server.Impedance(True) + # print('Impedance') + # self.zmqServer.open_Impedance = -1 + # elif self.zmqServer.open_Impedance == False: + # self.thread_data_server.Impedance(False) + # self.zmqServer.open_Impedance = -1 + + # if self.zmqServer.get_Impedance: # 返回阻抗值 + # # print(self.zmqServer.get_Impedance) + # # print(self.thread_data_server.GetDataLenCount()) + # if self.thread_data_server.GetDataLenCount() > 250: + # Impe_data = self.thread_data_server.getData(250) + # # 计算阻抗 + # imps = self.thread_data_server.getImpedance(Impe_data,self.zmqServer.decoder_class) + # self.zmqClient.send_to_all('impedance', imps.tolist()) + # else: + # pass + # if self.zmqServer.getReport: #返回训练报告内容 + # self.zmqServer.getReport = False + # allData = np.array(self.plotData) + # allLabel = np.array(self.plotLabel) + 1 + # nTrials = min(len(allLabel),len(allData)) + # if nTrials < 30: + # self.zmqClient.send_to_all('miReport',0) + # else: + # allData = allData[:nTrials] + # allLabel = allLabel[:nTrials] + # ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1', + # 'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4'] + # compare_names = ['C3', 'CZ', 'C4'] + # miReport = plotMain(ch_names=ch_names,compare_names=compare_names,Data=allData,labels=allLabel,MI_label=1,Rest_label=2, + # fs=self.device_info['sample_rate']) + # self.zmqClient.send_to_all('miReport',miReport) # --- 取数优先:先执行 decoder(消费环形缓冲),再处理 plot/report 等重负载 --- @@ -250,34 +273,33 @@ class Decoder_main(threading.Thread, device_type): self.decoder_SSMVEP() elif self.decoder_class == 'mi': self.decoder_MI() - elif self.decoder_class == 'concentration': - self.decoder_concentration() - elif self.decoder_class == 'blink': - self.decoder_blink() - else: - if self.zmqServer.get_Impedance == False: # 非阻抗检测状态 - if self.thread_data_server.GetDataLenCount() < 25: - time.sleep(0.005) - continue; - self.thread_data_server.getData(25) + # elif self.decoder_class == 'concentration': + # self.decoder_concentration() + # elif self.decoder_class == 'blink': + # self.decoder_blink() + # else: + # self. + # # if self.zmqServer.get_Impedance == False: # 非阻抗检测状态 + # # if self.thread_data_server.GetDataLenCount() < 25: + # # time.sleep(0.005) + # # continue; + # # self.thread_data_server.getData(25) except Exception as e: - print(f"Decoder Loop Error: {e}") - import traceback - traceback.print_exc() + algo_log(f"Decoder Loop Error: {e}") time.sleep(0.1) # Prevent CPU spin if error is persistent def decoder_SSVEP(self): if self.zmqServer.StartDecode: self.zmqServer.StartDecode = False self.decodingSteps = 1 - self.thread_data_server.ResetAll() + self.zmqServer.paradigmBuffer.ResetAllPara() print('启动预测') - if self.thread_data_server.GetDataLenCount() < 50: + if self.zmqServer.paradigmBuffer.GetDataLenCount() < 50: time.sleep(0.005) return if self.zmqServer.get_Impedance != False: # 阻抗检测状态不解码 return - data = self.thread_data_server.getDataViaSSVEP(50) + data = self.zmqServer.paradigmBuffer.getDataViaSSVEP(50) data = data[:self.n_chan, :] if self.decodingSteps == 1 and hasattr(self,'dw'): # 开始预热 self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时 @@ -293,7 +315,7 @@ class Decoder_main(threading.Thread, device_type): print('预测结果:' + str(choosenNum) + ',计算次数:' + str(self.calculateCount)) self.calculateCount = 0 if self.decodingSteps == 3: # 发送解码后的信息 - self.zmqClient.send_to_all('result', int(choosenNum)) + self.zmqServer.broadcast_message('result', int(choosenNum)) self.decodingSteps = 0 print('发送给界面完成。') @@ -312,19 +334,19 @@ class Decoder_main(threading.Thread, device_type): formatted_time = now.strftime('%H:%M:%S.%f')[:-3] print('模型训练完成', formatted_time) self.load_model = True - self.zmqClient.send_to_all('paradigm', 1) + self.zmqServer.broadcast_message('paradigm', 1) '''训练阶段采集数据''' if self.zmqServer.state_mode == 'train': # 训练状态 if self.zmqServer.StartTrain: self.currentLabel = self.zmqServer.currentLabel self.zmqServer.StartTrain = False - if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \ + if self.thread_data_server.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \ self.train_epoch[1] \ + self.thread_data_server.event_inner_idx: time.sleep(0.0001) return - print('训练队列数据:', self.thread_data_server.GetDataLenCount()) + print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount()) trainTrial = self.thread_data_server.get_SSMVEPData() # 取出所有数据 print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, self.thread_data_server.event_inner_idx]) trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理 @@ -348,7 +370,7 @@ class Decoder_main(threading.Thread, device_type): formatted_time = now.strftime('%H:%M:%S.%f')[:-3] print('启动预测 ', formatted_time) - if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \ + if self.thread_data_server.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \ self.interval_epoch[1] \ + self.thread_data_server.event_inner_idx: time.sleep(0.0001) @@ -360,8 +382,8 @@ class Decoder_main(threading.Thread, device_type): self.thread_data_server.event_inner_idx + self.interval_epoch[ 0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]] pad_eeg_test = np.zeros( - (data.shape[0], int((self.sample_length + 0.1) * self.fs))) - pad_eeg_test[:, :int(self.sample_length * self.fs)] = data + (data.shape[0], int((self.sample_length + 0.1) * self.device_info['sample_rate']))) + pad_eeg_test[:, :int(self.sample_length * self.device_info['sample_rate'])] = data choosenNum, features_2 = self.decoder.predict(pad_eeg_test) if isinstance(choosenNum, np.ndarray): choosenNum = choosenNum[0] @@ -371,7 +393,7 @@ class Decoder_main(threading.Thread, device_type): print('发送给界面完成。') else: # 休息状态 if self.zmqServer.get_Impedance == False: # 非阻抗检测状态 - if self.thread_data_server.GetDataLenCount() < 25: + if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25: time.sleep(0.005) return self.thread_data_server.getData(25) @@ -419,12 +441,12 @@ class Decoder_main(threading.Thread, device_type): if self.zmqServer.StartTrain: self.currentLabel = self.zmqServer.currentLabel self.zmqServer.StartTrain = False - if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \ + if self.thread_data_server.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \ self.interval_epoch[1] \ + self.thread_data_server.event_inner_idx: time.sleep(0.0001) return - print('训练队列数据:', self.thread_data_server.GetDataLenCount()) + print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount()) originalTrial = self.thread_data_server.get_MIData() # 取出MI导联数据 print('取出的: ', originalTrial.shape, 'event: ', originalTrial[-2, self.thread_data_server.event_inner_idx]) trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理 @@ -448,7 +470,7 @@ class Decoder_main(threading.Thread, device_type): formatted_time = now.strftime('%H:%M:%S.%f')[:-3] print('启动预测 ', formatted_time) - if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \ + if self.thread_data_server.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \ self.interval_epoch[1] \ + self.thread_data_server.event_inner_idx: time.sleep(0.0001) @@ -477,128 +499,128 @@ class Decoder_main(threading.Thread, device_type): print(f'发送给界面完成,耗时{end - start:.3f}s。') else: # 休息状态 if self.zmqServer.get_Impedance == False: # 非阻抗检测状态 - if self.thread_data_server.GetDataLenCount() < 25: + if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25: time.sleep(0.005) return self.thread_data_server.getData(25) - def decoder_concentration(self): - if self.zmqServer.state_mode == 'predict': - if self.zmqServer.StartDecode: - self.zmqServer.StartDecode = False - self.thread_data_server.ResetAll() - now = datetime.now() - formatted_time = now.strftime('%H:%M:%S.%f')[:-3] - print('启动专注力预测 ', formatted_time) - if self.thread_data_server.GetDataLenCount() < int(self.win_step * self.fs): # 每win_step得出一次结果 - time.sleep(0.005) - return - if self.zmqServer.get_Impedance != False: # 阻抗检测状态不解码 - return - data = self.thread_data_server.get_concentrateData(int(self.win_step * self.fs)) # 修改每次读取的数据 - result = self.calculate.queueOpt(data) - if result is not None: - self.zmqClient.send_to_all('result', int(result)) - else: # 休息状态 - if self.zmqServer.get_Impedance == False: # 非阻抗检测状态 - if self.thread_data_server.GetDataLenCount() < 25: - time.sleep(0.005) - return - self.thread_data_server.getData(25) + # def decoder_concentration(self): + # if self.zmqServer.state_mode == 'predict': + # if self.zmqServer.StartDecode: + # self.zmqServer.StartDecode = False + # self.thread_data_server.ResetAll() + # now = datetime.now() + # formatted_time = now.strftime('%H:%M:%S.%f')[:-3] + # print('启动专注力预测 ', formatted_time) + # if self.thread_data_server.GetDataLenCount() < int(self.win_step * self.device_info['sample_rate']): # 每win_step得出一次结果 + # time.sleep(0.005) + # return + # if self.zmqServer.get_Impedance != False: # 阻抗检测状态不解码 + # return + # data = self.thread_data_server.get_concentrateData(int(self.win_step * self.device_info['sample_rate'])) # 修改每次读取的数据 + # result = self.calculate.queueOpt(data) + # if result is not None: + # self.zmqClient.send_to_all('result', int(result)) + # else: # 休息状态 + # if self.zmqServer.get_Impedance == False: # 非阻抗检测状态 + # if self.thread_data_server.GetDataLenCount() < 25: + # time.sleep(0.005) + # return + # self.thread_data_server.getData(25) #### Blink detection ##### - def check_double_blink(self, current_time): - """ - 检查是否检测到连续两次眨眼 - @param current_time: 当前眨眼时间戳 - @return: True表示检测到连续两次眨眼 - """ - if len(self.blink_timestamps) < 2: - return False + # def check_double_blink(self, current_time): + # """ + # 检查是否检测到连续两次眨眼 + # @param current_time: 当前眨眼时间戳 + # @return: True表示检测到连续两次眨眼 + # """ + # if len(self.blink_timestamps) < 2: + # return False - # 检查是否在去抖期内 - if self.last_double_blink_time > 0: - time_since_last_double_blink = current_time - self.last_double_blink_time - if time_since_last_double_blink < self.double_blink_jitter: - return False # 在去抖期内,忽略连续眨眼检测 - last_time = self.blink_timestamps[-1] # 当前眨眼 - prev_time = self.blink_timestamps[-2] # 上次眨眼 + # # 检查是否在去抖期内 + # if self.last_double_blink_time > 0: + # time_since_last_double_blink = current_time - self.last_double_blink_time + # if time_since_last_double_blink < self.double_blink_jitter: + # return False # 在去抖期内,忽略连续眨眼检测 + # last_time = self.blink_timestamps[-1] # 当前眨眼 + # prev_time = self.blink_timestamps[-2] # 上次眨眼 - interval = last_time - prev_time - if interval <= self.double_blink_interval: - return True + # interval = last_time - prev_time + # if interval <= self.double_blink_interval: + # return True - return False + # return False - def process_blink_detection(self): - """ - 在缓冲区数据上执行,单次眨眼检测 - """ - if len(self.fp1_buffer) < self.window_samples: - return + # def process_blink_detection(self): + # """ + # 在缓冲区数据上执行,单次眨眼检测 + # """ + # if len(self.fp1_buffer) < self.window_samples: + # return - fp1_data = np.array(list(self.fp1_buffer)[-self.window_samples:]) - fp2_data = np.array(list(self.fp2_buffer)[-self.window_samples:]) - # 计算FP1和FP2的平均 - fp12_mean = (fp1_data + fp2_data) / 2.0 - # 带通滤波 - try: - fp12_filtered = signal.filtfilt(self.blink_b, self.blink_a, fp12_mean) - except Exception as e: - print(f"Filter error: {e}") - return - F = np.diff(fp12_filtered) - if len(F) < 3: - return - b, d, e = blink_detection(F, self.fs, self.Dmin, self.Dmax, self.EMin, self.EMax) + # fp1_data = np.array(list(self.fp1_buffer)[-self.window_samples:]) + # fp2_data = np.array(list(self.fp2_buffer)[-self.window_samples:]) + # # 计算FP1和FP2的平均 + # fp12_mean = (fp1_data + fp2_data) / 2.0 + # # 带通滤波 + # try: + # fp12_filtered = signal.filtfilt(self.blink_b, self.blink_a, fp12_mean) + # except Exception as e: + # print(f"Filter error: {e}") + # return + # F = np.diff(fp12_filtered) + # if len(F) < 3: + # return + # b, d, e = blink_detection(F, self.device_info['sample_rate'], self.Dmin, self.Dmax, self.EMin, self.EMax) - if b == 1: - samples_since_last = self.total_samples - self.last_blink_time - time_since_last_ms = (samples_since_last / self.fs) * 1000 - if time_since_last_ms >= self.jitterwin: # self.jitterwin 单次眨眼去抖 using time_since_last_ms - self.blink_count += 1 - self.last_blink_time = self.total_samples - current_time = time.time() - self.blink_timestamps.append(current_time) - blink_event = { - 'count': self.blink_count, - 'time': current_time, - 'sample_index': self.total_samples, - 'duration_ms': d, - 'energy': e - } - self.blink_events.append(blink_event) - self.zmqClient.send_to_all('result', 1) # 检测到眨眼信号,通知上位机 - if self.check_double_blink(current_time): - self.double_blink_count += 1 - interval = self.blink_timestamps[-1] - self.blink_timestamps[-2] - double_blink_event = { - 'double_blink_count': self.double_blink_count, - 'blink1_time': self.blink_timestamps[-2], - 'blink2_time': self.blink_timestamps[-1], - 'interval': interval - } - self.double_blink_events.append(double_blink_event) - self.last_double_blink_time = current_time - self.zmqClient.send_to_all('result', 2) # 发送双次眨眼事件 + # if b == 1: + # samples_since_last = self.total_samples - self.last_blink_time + # time_since_last_ms = (samples_since_last / self.device_info['sample_rate']) * 1000 + # if time_since_last_ms >= self.jitterwin: # self.jitterwin 单次眨眼去抖 using time_since_last_ms + # self.blink_count += 1 + # self.last_blink_time = self.total_samples + # current_time = time.time() + # self.blink_timestamps.append(current_time) + # blink_event = { + # 'count': self.blink_count, + # 'time': current_time, + # 'sample_index': self.total_samples, + # 'duration_ms': d, + # 'energy': e + # } + # self.blink_events.append(blink_event) + # self.zmqClient.send_to_all('result', 1) # 检测到眨眼信号,通知上位机 + # if self.check_double_blink(current_time): + # self.double_blink_count += 1 + # interval = self.blink_timestamps[-1] - self.blink_timestamps[-2] + # double_blink_event = { + # 'double_blink_count': self.double_blink_count, + # 'blink1_time': self.blink_timestamps[-2], + # 'blink2_time': self.blink_timestamps[-1], + # 'interval': interval + # } + # self.double_blink_events.append(double_blink_event) + # self.last_double_blink_time = current_time + # self.zmqClient.send_to_all('result', 2) # 发送双次眨眼事件 - def decoder_blink(self): - if self.thread_data_server.GetDataLenCount() < 50: - time.sleep(0.005) - return - if self.zmqServer.get_Impedance == False: # 非阻抗检测状态 - data = self.thread_data_server.get_blinkData(50) - fp1_data = data[0, :] # ch1 (相当于FP1) - fp2_data = data[1, :] # ch2 (相当于FP2) - for i in range(len(fp1_data)): - self.fp1_buffer.append(fp1_data[i]) - self.fp2_buffer.append(fp2_data[i]) - self.total_samples += 1 - self.sample_counter += 1 + # def decoder_blink(self): + # if self.thread_data_server.GetDataLenCount() < 50: + # time.sleep(0.005) + # return + # if self.zmqServer.get_Impedance == False: # 非阻抗检测状态 + # data = self.thread_data_server.get_blinkData(50) + # fp1_data = data[0, :] # ch1 (相当于FP1) + # fp2_data = data[1, :] # ch2 (相当于FP2) + # for i in range(len(fp1_data)): + # self.fp1_buffer.append(fp1_data[i]) + # self.fp2_buffer.append(fp2_data[i]) + # self.total_samples += 1 + # self.sample_counter += 1 - if self.sample_counter >= self.step_samples: - self.process_blink_detection() - self.sample_counter = 0 + # if self.sample_counter >= self.step_samples: + # self.process_blink_detection() + # self.sample_counter = 0 def stop(self): ''' diff --git a/Zmq/dataBuffer.py b/Zmq/dataBuffer.py index 9bba7e6..a9016bc 100644 --- a/Zmq/dataBuffer.py +++ b/Zmq/dataBuffer.py @@ -63,8 +63,53 @@ class ParadigmRingBuffer: 获取最新缓存中每个通道的数量 @return: ''' - return self.nUpdate + return self.nUpdate + + # ========== 各范式数据访问接口 ========== + def get_MIData(self): + """获取MI导联数据 (21通道 + 事件)""" + data = self.getData(self.GetDataLenCount()) + rows_to_extract = [8, 15, 12, 14, 18, 23, 16, 59, 50, 58, 17, 45, 29, 11, 10, 19, 20, 61, 51, 60, 21, 64, 65] + row_to_select = np.array(rows_to_extract) + if data.shape[1] > 0: + return data[row_to_select, :] + return np.zeros((len(rows_to_extract), 0)) + def get_SSMVEPData(self): + """获取SSMVEP导联数据 (8通道 + 事件)""" + data = self.getData(self.GetDataLenCount()) + rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64, 65] + row_to_select = np.array(rows_to_extract) + if data.shape[1] > 0: + return data[row_to_select, :] + return np.zeros((len(rows_to_extract), 0)) + + def getDataViaSSVEP(self, count): + """获取SSVEP数据 (8通道 + 事件)""" + data = self.getData(count) + rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64] + row_to_select = np.array(rows_to_extract) + if data.shape[1] > 0: + return data[row_to_select, :] + return np.zeros((len(rows_to_extract), 0)) + + def get_concentrateData(self, count): + """获取专注力数据 (2通道)""" + data = self.getData(count) + rows_to_extract = [0, 1] + row_to_select = np.array(rows_to_extract) + if data.shape[1] > 0: + return data[row_to_select, :] + return np.zeros((len(rows_to_extract), 0)) + + def get_blinkData(self, count): + """获取眨眼数据 (2通道)""" + data = self.getData(count) + rows_to_extract = [0, 1] + row_to_select = np.array(rows_to_extract) + if data.shape[1] > 0: + return data[row_to_select, :] + return np.zeros((len(rows_to_extract), 0)) # reset buffer def resetAllPara(self): @@ -72,6 +117,4 @@ class ParadigmRingBuffer: self.currentPtr = 0 self.readPtr = 0 # add by lizhenhua 清空读指针 self.buffer = np.zeros((self.n_chan, self.n_points)) # add by lizhenhua 清空环形缓冲区 - - diff --git a/Zmq/zmqServer.py b/Zmq/zmqServer.py index c071bd1..fbf0eac 100644 --- a/Zmq/zmqServer.py +++ b/Zmq/zmqServer.py @@ -1,16 +1,22 @@ +import ast import numpy as np -import zmq import threading import json import queue +from typing import Dict # from Device.SunnyLinker import SunnyLinker64 from dataBuffer import ParadigmRingBuffer from filterProcess import FilterRingBuffer +from PubLibrary.InifileHelper import IniRead from logs.log import algo_log +import zmq + class zmqServer(threading.Thread): def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100, device_info=None): threading.Thread.__init__(self) + self.device_info = device_info + self.host = host self.cmd_port = cmd_port # 命令交互端口 self.data_port = data_port # 数据接收端口 @@ -28,8 +34,8 @@ class zmqServer(threading.Thread): self.daemon = True # 范式数据缓存 - self.paradigmBuffer = ParadigmRingBuffer(66, 2500) - self.filterBuffer = FilterRingBuffer(66, 2500) + self.paradigmBuffer = ParadigmRingBuffer(self.device_info['channel_nums'], self.device_info['sample_rate'] * 10) + self.filterBuffer = FilterRingBuffer(self.device_info['channel_nums'], self.device_info['sample_rate'] * 10) # 命令与数据通信 @@ -64,6 +70,77 @@ class zmqServer(threading.Thread): self.cmd_clients = set() # 命令端口客户端ID self.data_clients = set() # 数据端口客户端ID self.send_queue = queue.Queue() # 发送队列(仅用于命令端口广播) + + + # 范式buffer参数, 事件检测相关 + self._event_lock = threading.Lock() + self._epoch_finished = False + self._event_inner_idx = -1 + self.pack_contain_event = False + self.predict_event = 99 + self.events = [1, 2, self.predict_event] + self.count_events = {} + self.latency = 50 + self.train_latency = 50 + self._interval_inited = False + + @property + def interval_inited(self): + return self._interval_inited + + @interval_inited.setter + def interval_inited(self, value): + self._interval_inited = value + + @property + def epoch_finished(self): + with self._event_lock: + return self._epoch_finished + + @epoch_finished.setter + def epoch_finished(self, value): + with self._event_lock: + self._epoch_finished = value + + @property + def event_inner_idx(self): + with self._event_lock: + return self._event_inner_idx + + @event_inner_idx.setter + def event_inner_idx(self, value): + with self._event_lock: + self._event_inner_idx = value + + def interval_init(self, decoder_class): + if decoder_class == 'ssmvep': + interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch')) + self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] # epoch截取信息 + self.train_epoch = [int(self.interval_epoch[0]), + int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])] # 训练样本epoch + self.latency = (self.interval_epoch[ + 1] + 0.1 * self.device_info['sample_rate']) // 5 # 提取epoch的延迟标记,5代表每次解包得到的5位采样点;0.1表示比实际需要的长度多取0.1,会被截掉 + self.train_latency = (self.train_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5 + + elif decoder_class == 'mi': + interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch')) + self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] # epoch截取信息 + self.train_epoch = self.interval_epoch.copy() + self.latency = (self.interval_epoch[1]) // 5 # 提取epoch的延迟标记,5代表每次解包得到的5位采样点; + self.train_latency = self.latency + + print('时间窗:', (interval_epoch)) + self.count_events: Dict[str, int] = {} # 表示包延迟的计数信息 + self.event_inner_idx = -1 # event在5位数据包内部的idx + self.epoch_finished = False # 接收epoch是否完整 + self.pack_contain_event = False # 当前包是否含有event + self.predict_event = 99 + self.events = [1, 2, self.predict_event] + self.interval_inited = True + # if getattr(self, 'serial', None) and self.serial.is_open: + # self.serial.close() + # self.serial = serial.Serial(self.serial_port, 460800, timeout=1) # 连接同步器串口 + def broadcast_message(self, method, params): """Put message into queue to be sent to all command clients""" @@ -78,15 +155,15 @@ class zmqServer(threading.Thread): # 注册新的命令客户端 if ident not in self.cmd_clients: self.cmd_clients.add(ident) - print(f"New CMD Client Connected: {ident} (port: {self.cmd_port})") + algo_log(f"New CMD Client Connected: {ident} (port: {self.cmd_port})") # 解析消息 try: message = json.loads(message_bytes.decode('utf-8')) except json.JSONDecodeError: - print(f"Invalid JSON from CMD client {ident}") - continue - print(f"Received CMD request: {message}") + algo_log(f"Invalid JSON from CMD client {ident}") + return + algo_log(f"Received CMD request: {message}") method = message.get("method") params = message.get("params") @@ -94,37 +171,40 @@ class zmqServer(threading.Thread): # 原有命令处理逻辑 if method == "sync": self.state_mode = 'sync' - if method == "targetFreqs": + elif method == "targetFreqs": if not isinstance(params, list): - print('targetFreqs must be a list') - continue + algo_log(f"targetFreqs must be a list") + return if params != self.targetFreqs: self.targetFreqs = params self.changeTarget = True - if method == "decoderClass": + elif method == "decoderClass": if not isinstance(params, str): - print('decoderClass must be a str') - continue + algo_log(f"decoderClass must be a str") + return if params != self.decoder_class: self.decoder_class = params self.decoder_switch = True - if method == "getReport": - self.getReport = True - if method == "train":#训练状态 + elif method == "train":#训练状态 self.state_mode = 'train' self.StartTrain = True self.currentLabel = params # 当前刺激端的训练标签 - self.sunnyLinker.push_trigger(self.labels[self.currentLabel]) + # self.sunnyLinker.push_trigger(self.labels[self.currentLabel]) elif method == "predict":#预测状态 self.state_mode = 'predict' if params == 1: #开始解码 self.StartDecode = True - self.sunnyLinker.push_trigger(0x63) + # self.sunnyLinker.push_trigger(0x63) elif params == 2: #停止解码 self.IsExitApp = True self.running = False elif method == "rest": #休息状态 self.state_mode = 'rest' + else: + algo_log(f"未知命令:{method}", level="WARNING") + + # elif method == "getReport": + # self.getReport = True # elif method == "impedance": # if params == 1: # self.open_Impedance = True # 开启阻抗 @@ -153,7 +233,7 @@ class zmqServer(threading.Thread): try: # 3. 精确长度校验(核心:固定(5,66) float32 = 5*66*4=1320字节,与int32字节数相同) - EXPECTED_BYTES = 5 * 66 * 4 # 每个float32占4字节 + EXPECTED_BYTES = self.device_info['frame_points'] * self.device_info['channel_nums'] * 4 # 每个float32占4字节 if len(data_bytes) != EXPECTED_BYTES: print(f"[ERROR] 数据长度错误:期望{EXPECTED_BYTES}字节,实际{len(data_bytes)}字节") return @@ -162,7 +242,7 @@ class zmqServer(threading.Thread): # 步骤:字节流 → (330,) float32数组 → (5,66) 原始格式 → 转置为 (66,5) 缓冲区标准格式 data_np = np.frombuffer(data_bytes, dtype=np.float32) # 重塑为上位机原始维度 - data_np = data_np.reshape(5, 66) + data_np = data_np.reshape(self.device_info['frame_points'], self.device_info['channel_nums']) # 转置为(通道数, 采样点数)标准格式,转换为float64保证滤波运算精度 data_np = data_np.T.astype(np.float64) @@ -215,7 +295,7 @@ class zmqServer(threading.Thread): self._process_send_queue() # 2. 轮训监听两个Socket的输入事件(10ms超时,避免阻塞) - socks = dict(self.poller.poll(10)) + socks = dict(self.poller.poll(50)) # 处理命令端口消息 if self.cmd_socket in socks and socks[self.cmd_socket] == zmq.POLLIN: diff --git a/config.ini b/config.ini index c90639f..1970d9b 100644 --- a/config.ini +++ b/config.ini @@ -24,10 +24,11 @@ console_output = 1 ; 64 导设备配置 [device_type_1] -device_sample_rate = 250 -device_channel_nums = 66 -device_channel_names = ['FP1', 'FP2', 'FC1', 'FC2', 'CP1', 'CP2', 'F3', 'F4', 'P3', 'P4', 'O1', 'O2', 'FT9', 'FT10', 'F7', 'F8', 'TP9', 'TP10', 'AF4', 'PO8', 'PZ', 'FCZ'] -device_channel_index = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18] +sample_rate = 250 +frame_points = 5 +channel_nums = 66 +channel_names = ['FP1', 'FP2', 'FC1', 'FC2', 'CP1', 'CP2', 'F3', 'F4', 'P3', 'P4', 'O1', 'O2', 'FT9', 'FT10', 'F7', 'F8', 'TP9', 'TP10', 'AF4', 'PO8', 'PZ', 'FCZ'] +channel_index = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18] diff --git a/runDecoder.py b/runDecoder.py index 71cdf92..7817b1e 100644 --- a/runDecoder.py +++ b/runDecoder.py @@ -9,29 +9,28 @@ from PubLibrary.InifileHelper import IniRead def get_device_info(device_type): - section = f'device_type_{device_type}' device_info = { - 'device_sample_rate': int(IniRead(section, 'sample_rate')) if IniRead(section, 'sample_rate') is not None else 250, - - '' + 'sample_rate': int(IniRead(section, 'sample_rate')) if IniRead(section, 'sample_rate') is not None else 250, + 'channel_nums': int(IniRead(section, 'channel_nums')) if IniRead(section, 'channel_nums') is not None else 66, + 'channel_names': IniRead(section, 'channel_names') if IniRead(section, 'channel_names') is not None else None, + 'channel_index': IniRead(section, 'channel_index') if IniRead(section, 'channel_index') is not None else None, } + + return device_info + if __name__ == "__main__": if not is_program_running(): # 解析命令行参数 - parser = argparse.ArgumentParser(description="EEG Decoder Application") - parser.add_argument('-dt', '-t','--device-type', type=int, default=None, help="Device Type") + # parser = argparse.ArgumentParser(description="EEG Decoder Application") + # parser.add_argument('-dt', '-t','--device-type', type=int, default=None, help="Device Type") # parser.add_argument('-dh', '--device-host', type=str, default=None, help="Device Host IP") # parser.add_argument('-dp', '--device-port', type=int, default=None, help="Device Port") # parser.add_argument('-uh', '--upper-host', type=str, default=None, help="Upper Computer Host IP") # parser.add_argument('-up', '--upper-port', type=int, default=None, help="Upper Computer Port") + # args = parser.parse_args() - args = parser.parse_args() - device_info= get_device_info(args.device_type) - - - decoder = Decoder_main(device_info=device_info) # decoder.connect( # device_type=args.device_type, # device_host=args.device_host, @@ -40,6 +39,9 @@ if __name__ == "__main__": # upper_port=args.upper_port # ) + device_info= get_device_info(1) + decoder = Decoder_main(device_info=device_info) + try: decoder.start() while not decoder.zmqServer.IsExitApp: