Compare commits

..

11 Commits

Author SHA1 Message Date
Ivey Song
30c690e4e3 Merge branch 'master' of http://47.98.56.110:7001/lizhao/bci_algo 2026-06-06 16:08:31 +08:00
Ivey Song
853037726d update 2026-06-06 16:05:32 +08:00
8a9d9a5c78 update log 2026-06-06 16:04:33 +08:00
29b6118f11 v1 2026-06-06 15:53:50 +08:00
fce7d93d5e delete server1 2026-06-06 15:15:10 +08:00
949801198e update 2026-06-06 15:13:23 +08:00
Ivey Song
494515463d 专注力计算 2026-06-06 14:57:52 +08:00
Ivey Song
9a655ffdeb 删除以往pth模型 2026-06-06 14:51:00 +08:00
Ivey Song
a9fd51e935 brainmap 2026-06-06 14:50:16 +08:00
Ivey Song
4b7e48be38 数据模拟 2026-06-06 14:49:38 +08:00
2d190d6431 add buffer 2026-06-06 14:40:07 +08:00
9 changed files with 688 additions and 825 deletions

View File

@@ -1,6 +1,7 @@
import ast import ast
import glob import glob
import os import os
import sys
import threading import threading
from datetime import datetime from datetime import datetime
import multiprocessing as mp import multiprocessing as mp
@@ -10,7 +11,7 @@ import torch
from queue import Empty from queue import Empty
from scipy import signal from scipy import signal
from torch.autograd import Variable from torch.autograd import Variable
from Device.SunnyLinker import SunnyLinker64 # from Device.SunnyLinker import SunnyLinker64
from SSMVEP.algorithm.tdca import TDCA from SSMVEP.algorithm.tdca import TDCA
from SSMVEP.algorithm.base import generate_cca_references from SSMVEP.algorithm.base import generate_cca_references
from concentration.algorithm.calculate_focus import Calculate from concentration.algorithm.calculate_focus import Calculate
@@ -19,49 +20,70 @@ from Zmq.zmqServer import zmqServer
from Zmq.zmqClient import zmqClient from Zmq.zmqClient import zmqClient
from MI.Algorithm.conformer_2class import onlineTrain from MI.Algorithm.conformer_2class import onlineTrain
from PubLibrary.InifileHelper import IniRead from PubLibrary.InifileHelper import IniRead
from logs.log import algo_log
from SSVEP.dwfbcca import FbccaDw from SSVEP.dwfbcca import FbccaDw
from Tools.plot_MI_EEG import plotMain # from Tools.plot_MI_EEG import plotMain
from collections import deque from collections import deque
class Decoder_main(threading.Thread, device_type):
def __init__(self, device_type=None): def get_root_path():
"""
Nuitka 打包专用:获取程序根目录(.py 或 .exe 所在目录)
"""
if getattr(sys, 'frozen', False):
# 打包后:返回 exe 所在目录
return os.path.dirname(sys.executable)
else:
# 开发时:返回 py 文件所在目录
return os.path.dirname(os.path.abspath(__file__))
MODEL_FOLDER = "online_Models"
class Decoder_main(threading.Thread):
def __init__(self, device_info=None):
threading.Thread.__init__(self) threading.Thread.__init__(self)
self.device_info = {
'sample_rate': device_info['sample_rate'],
'frame_points': device_info['frame_points'],
'channel_nums': device_info['channel_nums'],
'channel_names': device_info['channel_names'],
'channel_index': device_info['channel_index'],
}
self.Runing=True self.Runing=True
self.decoder = None self.decoder = None
self.fs = 250 # 采样率
self.energy = 0 # 电量
self.status_code = 0 # 与采集设备通信的状态码0为异常1为正常
self.decoder_class = None #解码器类别 self.decoder_class = None #解码器类别
# 与采集设备通信的状态码0为异常1为正常
# self.status_code = 0
# self.device_info['sample_rate'] = 250 # 采样率
# self.energy = 0 # 电量
self.decodingSteps = 0 # 0=停止解码 1=预热 2=解码中 3=解码完成,发送解码结果 self.decodingSteps = 0 # 0=停止解码 1=预热 2=解码中 3=解码完成,发送解码结果
self.device_info = {
'device_type': None,
'sample_rate': None,
'channel_num': None,
}
def connect(self, device_type=None, device_host=None, device_port=None, upper_host=None, upper_port=None): self.zmqServer = zmqServer(device_info=self.device_info)
self.DeviceType = device_type if device_type is not None else int(IniRead('system', 'Device_type'))
_device_host = device_host if device_host is not None else str(IniRead('system', 'Device_Host'))
_device_port = device_port if device_port is not None else int(IniRead('system', 'Device_Port'))
_upper_host = upper_host if upper_host is not None else str(IniRead('system', 'Upper_Host'))
_upper_port = upper_port if upper_port is not None else int(IniRead('system', 'Upper_Port'))
if self.DeviceType == 1:
self.thread_data_server = SunnyLinker64(_device_host, _device_port, self.fs, 64, method='tcp')
self.thread_data_server.host = _device_host
self.thread_data_server.port = _device_port
self.thread_data_server.toUv = True
self.thread_data_server.start()
self.zmqServer = zmqServer()
self.zmqServer.start() self.zmqServer.start()
self.zmqClient = zmqClient(_upper_host, _upper_port) # self.zmqClient = zmqClient(_upper_host, _upper_port)
self.zmqClient.set_zmq_server(self.zmqServer) # self.zmqClient.set_zmq_server(self.zmqServer)
self.zmqClient.connect() # self.zmqClient.connect()
# def connect(self, device_type=None, device_host=None, device_port=None, upper_host=None, upper_port=None):
# self.DeviceType = device_type if device_type is not None else int(IniRead('system', 'Device_type'))
# _device_host = device_host if device_host is not None else str(IniRead('system', 'Device_Host'))
# _device_port = device_port if device_port is not None else int(IniRead('system', 'Device_Port'))
# _upper_host = upper_host if upper_host is not None else str(IniRead('system', 'Upper_Host'))
# _upper_port = upper_port if upper_port is not None else int(IniRead('system', 'Upper_Port'))
# if self.DeviceType == 1:
# self.thread_data_server = SunnyLinker64(_device_host, _device_port, self.device_info['sample_rate'], 64, method='tcp')
# self.thread_data_server.host = _device_host
# self.thread_data_server.port = _device_port
# self.thread_data_server.toUv = True
# self.thread_data_server.start()
def is_valid_signal(self, data, threshold=1e5): # 判断当前信号是否为有效信号 def is_valid_signal(self, data, threshold=1e5): # 判断当前信号是否为有效信号
# data: (chans, samples) # data: (chans, samples)
@@ -78,26 +100,25 @@ class Decoder_main(threading.Thread, device_type):
self.decoder_class = decoder_class self.decoder_class = decoder_class
if decoder_class == 'ssvep' or decoder_class == 'pvs': if decoder_class == 'ssvep' or decoder_class == 'pvs':
self.n_chan = 8 self.n_chan = 8
self.thread_data_server.interval_inited = False # self.thread_data_server.interval_inited = False
DW_cost_method, self.DW_cost_tv = ast.literal_eval(IniRead('system', 'SSVEP_ThresholdValue')) DW_cost_method, self.DW_cost_tv = ast.literal_eval(IniRead('system', 'SSVEP_ThresholdValue'))
self.ListFreq = self.zmqServer.targetFreqs self.ListFreq = self.zmqServer.targetFreqs
self.num_target = len(self.ListFreq) self.num_target = len(self.ListFreq)
if self.num_target == 0: if self.num_target == 0:
return return
# 初始化对象 二代算法 # 初始化对象 二代算法
self.dw = FbccaDw(self.fs, self.num_target, self.n_chan, 5, 5, self.dw = FbccaDw(self.device_info['sample_rate'], self.num_target, self.n_chan, 5, 5,
0.2, [2.0, 0.1], [8, 7], 50, DW_cost_method) 0.2, [2.0, 0.1], [8, 7], 50, DW_cost_method)
# frequence band # frequence band
self.dw.filterFrequenceBank() self.dw.filterFrequenceBank()
self.dw.setNotchFilterPara() self.dw.setNotchFilterPara()
self.calculateCount = 0 self.calculateCount = 0
self.referenceData = self.dw.reference(self.ListFreq, int(50 * 0.2 * self.fs), self.referenceData = self.dw.reference(self.ListFreq, int(50 * 0.2 * self.device_info['sample_rate']), 5)
5)
self.dw.filterInit() self.dw.filterInit()
self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时 self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时
elif decoder_class == 'ssmvep': elif decoder_class == 'ssmvep':
self.thread_data_server.interval_init(decoder_class) self.zmqServer.interval_init(decoder_class)
self.n_chan = 8 self.n_chan = 8
self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch')) self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位 self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位
@@ -106,12 +127,12 @@ class Decoder_main(threading.Thread, device_type):
self.list_freqs = np.array([8, 9]) # 刺激频率 self.list_freqs = np.array([8, 9]) # 刺激频率
self.list_phase = np.array([0, 0]) # 相位 self.list_phase = np.array([0, 0]) # 相位
self.tdca = TDCA(padding_len=5, n_components=1) self.tdca = TDCA(padding_len=5, n_components=1)
self.Yf = generate_cca_references(self.list_freqs, srate=self.fs, T=self.sample_length, self.Yf = generate_cca_references(self.list_freqs, srate=self.device_info['sample_rate'], T=self.sample_length,
phases=self.list_phase, n_harmonics=5) phases=self.list_phase, n_harmonics=5)
self.parameter_init(5,45) self.parameter_init(5,45)
elif decoder_class == 'mi' or decoder_class == 'ma': elif decoder_class == 'mi' or decoder_class == 'ma':
self.thread_data_server.interval_init(decoder_class) self.zmqServer.interval_init(decoder_class)
self.n_chan = 21 self.n_chan = 21
self.interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch')) self.interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位 self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位
@@ -126,7 +147,7 @@ class Decoder_main(threading.Thread, device_type):
# self.win_len = 10 # self.win_len = 10
# self.win_step = 1 # self.win_step = 1
# self.low_threshold, self.high_threshold = ast.literal_eval(IniRead('system', 'concentration_ThresholdValue')) # self.low_threshold, self.high_threshold = ast.literal_eval(IniRead('system', 'concentration_ThresholdValue'))
# self.calculate = Calculate(self.low_threshold, self.high_threshold, self.fs, self.win_len) # self.calculate = Calculate(self.low_threshold, self.high_threshold, self.device_info['sample_rate'], self.win_len)
# self.interval_epoch = [0, 1] # self.interval_epoch = [0, 1]
# self.parameter_init(2, 40) # self.parameter_init(2, 40)
# # self.eegQueue moved to Calculate class # # self.eegQueue moved to Calculate class
@@ -138,8 +159,8 @@ class Decoder_main(threading.Thread, device_type):
# self.total_samples = 0 # 总采样点数 # self.total_samples = 0 # 总采样点数
# self.window_ms = 600 # 检测窗口大小 (ms) # self.window_ms = 600 # 检测窗口大小 (ms)
# self.step_ms = 100 # 滑动步长 (ms) # self.step_ms = 100 # 滑动步长 (ms)
# self.window_samples = int(self.window_ms * self.fs / 1000) # 150个样本点 # self.window_samples = int(self.window_ms * self.device_info['sample_rate'] / 1000) # 150个样本点
# self.step_samples = int(self.step_ms * self.fs / 1000) # 25个样本点 # self.step_samples = int(self.step_ms * self.device_info['sample_rate'] / 1000) # 25个样本点
# self.buffer_size = self.window_samples + self.step_samples * 5 # self.buffer_size = self.window_samples + self.step_samples * 5
# self.fp1_buffer = deque(maxlen=self.buffer_size) # self.fp1_buffer = deque(maxlen=self.buffer_size)
# self.fp2_buffer = deque(maxlen=self.buffer_size) # self.fp2_buffer = deque(maxlen=self.buffer_size)
@@ -153,11 +174,11 @@ class Decoder_main(threading.Thread, device_type):
# self.double_blink_events = [] # 连续眨眼事件记录 # self.double_blink_events = [] # 连续眨眼事件记录
# self.last_double_blink_time = 0 # 上次检测到连续眨眼的时间戳 # self.last_double_blink_time = 0 # 上次检测到连续眨眼的时间戳
# self.blink_events = [] # self.blink_events = []
# self.blink_b, self.blink_a = signal.butter(4, [self.l_freq / (self.fs / 2), self.h_freq / (self.fs / 2)], btype='band') # self.blink_b, self.blink_a = signal.butter(4, [self.l_freq / (self.device_info['sample_rate'] / 2), self.h_freq / (self.device_info['sample_rate'] / 2)], btype='band')
def parameter_init(self,bandPass_low,bandPass_high): def parameter_init(self,bandPass_low,bandPass_high):
self.interval_epoch = [int(i * self.fs) for i in self.interval_epoch] # epoch截取信息 self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in self.interval_epoch] # epoch截取信息
self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * self.fs)] # 训练样本epoch self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])] # 训练样本epoch
self.trainData = [] #训练数据 self.trainData = [] #训练数据
self.trainLabel = [] #训练标签 self.trainLabel = [] #训练标签
self.plotData = [] #报告分析数据 self.plotData = [] #报告分析数据
@@ -165,10 +186,10 @@ class Decoder_main(threading.Thread, device_type):
self.currentLabel = -1 #刺激界面当前显示的训练标签 self.currentLabel = -1 #刺激界面当前显示的训练标签
self.train_started = False #是否开始训练模型 self.train_started = False #是否开始训练模型
self.load_model = False # 调用模型是否完成的标志 self.load_model = False # 调用模型是否完成的标志
self.b_notch, self.a_notch = signal.iirnotch(50 / (self.fs/2), 30) # 50Hz工频陷波250是采样率30是质量因子 self.b_notch, self.a_notch = signal.iirnotch(50 / (self.device_info['sample_rate']/2), 30) # 50Hz工频陷波250是采样率30是质量因子
self.b_design = signal.firwin(65, [bandPass_low / (self.fs/2), bandPass_high / (self.fs/2)], pass_zero=False) # 设计8-30Hz带通滤波器 self.b_design = signal.firwin(65, [bandPass_low / (self.device_info['sample_rate']/2), bandPass_high / (self.device_info['sample_rate']/2)], pass_zero=False) # 设计8-30Hz带通滤波器
fileName = 'Model_' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S') fileName = 'Model_' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
filePath = './online_Models/' filePath = os.path.join(get_root_path(), MODEL_FOLDER) + os.sep
for old_pth in glob.glob(os.path.join(filePath, '*.pth')): for old_pth in glob.glob(os.path.join(filePath, '*.pth')):
os.remove(old_pth) os.remove(old_pth)
self.modelPath = ''.join([filePath, fileName, '.pth']) self.modelPath = ''.join([filePath, fileName, '.pth'])
@@ -196,54 +217,55 @@ class Decoder_main(threading.Thread, device_type):
# 同步信息 # 同步信息
if self.zmqServer.state_mode == 'sync': if self.zmqServer.state_mode == 'sync':
self.zmqClient.send_to_all('sync', self.zmqClient.state) # self.zmqClient.send_to_all('sync', self.zmqClient.state)
self.zmqServer.state_mode = 'rest' self.zmqServer.state_mode = 'rest'
# 状态异常,报告上位机
if self.status_code != self.thread_data_server.status_code:
self.status_code = self.thread_data_server.status_code
self.zmqClient.send_to_all('status_code', int(self.status_code))
print('status code')
# 返回电量 # # 状态异常,报告上位机
if self.energy != self.thread_data_server.energy: # if self.status_code != self.thread_data_server.status_code:
self.energy = self.thread_data_server.energy # self.status_code = self.thread_data_server.status_code
self.zmqClient.send_to_all('energy', int(self.energy)) # self.zmqClient.send_to_all('status_code', int(self.status_code))
print('energy') # print('status code')
if self.zmqServer.open_Impedance == True: # 开启阻抗检测功能,仅运行一次 # # 返回电量
self.thread_data_server.Impedance(True) # if self.energy != self.thread_data_server.energy:
print('Impedance') # self.energy = self.thread_data_server.energy
self.zmqServer.open_Impedance = -1 # self.zmqClient.send_to_all('energy', int(self.energy))
elif self.zmqServer.open_Impedance == False: # print('energy')
self.thread_data_server.Impedance(False)
self.zmqServer.open_Impedance = -1
if self.zmqServer.get_Impedance: # 返回阻抗值 # if self.zmqServer.open_Impedance == True: # 开启阻抗检测功能,仅运行一次
# print(self.zmqServer.get_Impedance) # self.thread_data_server.Impedance(True)
# print(self.thread_data_server.GetDataLenCount()) # print('Impedance')
if self.thread_data_server.GetDataLenCount() > 250: # self.zmqServer.open_Impedance = -1
Impe_data = self.thread_data_server.getData(250) # elif self.zmqServer.open_Impedance == False:
# 计算阻抗 # self.thread_data_server.Impedance(False)
imps = self.thread_data_server.getImpedance(Impe_data,self.zmqServer.decoder_class) # self.zmqServer.open_Impedance = -1
self.zmqClient.send_to_all('impedance', imps.tolist())
else: # if self.zmqServer.get_Impedance: # 返回阻抗值
pass # # print(self.zmqServer.get_Impedance)
if self.zmqServer.getReport: #返回训练报告内容 # # print(self.thread_data_server.GetDataLenCount())
self.zmqServer.getReport = False # if self.thread_data_server.GetDataLenCount() > 250:
allData = np.array(self.plotData) # Impe_data = self.thread_data_server.getData(250)
allLabel = np.array(self.plotLabel) + 1 # # 计算阻抗
nTrials = min(len(allLabel),len(allData)) # imps = self.thread_data_server.getImpedance(Impe_data,self.zmqServer.decoder_class)
if nTrials < 30: # self.zmqClient.send_to_all('impedance', imps.tolist())
self.zmqClient.send_to_all('miReport',0) # else:
else: # pass
allData = allData[:nTrials] # if self.zmqServer.getReport: #返回训练报告内容
allLabel = allLabel[:nTrials] # self.zmqServer.getReport = False
ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1', # allData = np.array(self.plotData)
'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4'] # allLabel = np.array(self.plotLabel) + 1
compare_names = ['C3', 'CZ', 'C4'] # nTrials = min(len(allLabel),len(allData))
miReport = plotMain(ch_names=ch_names,compare_names=compare_names,Data=allData,labels=allLabel,MI_label=1,Rest_label=2, # if nTrials < 30:
fs=self.fs) # self.zmqClient.send_to_all('miReport',0)
self.zmqClient.send_to_all('miReport',miReport) # else:
# allData = allData[:nTrials]
# allLabel = allLabel[:nTrials]
# ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1',
# 'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4']
# compare_names = ['C3', 'CZ', 'C4']
# miReport = plotMain(ch_names=ch_names,compare_names=compare_names,Data=allData,labels=allLabel,MI_label=1,Rest_label=2,
# fs=self.device_info['sample_rate'])
# self.zmqClient.send_to_all('miReport',miReport)
# --- 取数优先:先执行 decoder消费环形缓冲再处理 plot/report 等重负载 --- # --- 取数优先:先执行 decoder消费环形缓冲再处理 plot/report 等重负载 ---
@@ -254,34 +276,33 @@ class Decoder_main(threading.Thread, device_type):
self.decoder_SSMVEP() self.decoder_SSMVEP()
elif self.decoder_class == 'mi': elif self.decoder_class == 'mi':
self.decoder_MI() self.decoder_MI()
elif self.decoder_class == 'concentration': # elif self.decoder_class == 'concentration':
self.decoder_concentration() # self.decoder_concentration()
elif self.decoder_class == 'blink': # elif self.decoder_class == 'blink':
self.decoder_blink() # self.decoder_blink()
else: # else:
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态 # self.
if self.thread_data_server.GetDataLenCount() < 25: # # if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
time.sleep(0.005) # # if self.thread_data_server.GetDataLenCount() < 25:
continue; # # time.sleep(0.005)
self.thread_data_server.getData(25) # # continue;
# # self.thread_data_server.getData(25)
except Exception as e: except Exception as e:
print(f"Decoder Loop Error: {e}") algo_log(f"Decoder Loop Error: {e}")
import traceback
traceback.print_exc()
time.sleep(0.1) # Prevent CPU spin if error is persistent time.sleep(0.1) # Prevent CPU spin if error is persistent
def decoder_SSVEP(self): def decoder_SSVEP(self):
if self.zmqServer.StartDecode: if self.zmqServer.StartDecode:
self.zmqServer.StartDecode = False self.zmqServer.StartDecode = False
self.decodingSteps = 1 self.decodingSteps = 1
self.thread_data_server.ResetAll() self.zmqServer.paradigmBuffer.ResetAllPara()
print('启动预测') print('启动预测')
if self.thread_data_server.GetDataLenCount() < 50: if self.zmqServer.paradigmBuffer.GetDataLenCount() < 50:
time.sleep(0.005) time.sleep(0.005)
return return
if self.zmqServer.get_Impedance != False: # 阻抗检测状态不解码 if self.zmqServer.get_Impedance != False: # 阻抗检测状态不解码
return return
data = self.thread_data_server.getDataViaSSVEP(50) data = self.zmqServer.paradigmBuffer.getDataViaSSVEP(50)
data = data[:self.n_chan, :] data = data[:self.n_chan, :]
if self.decodingSteps == 1 and hasattr(self,'dw'): # 开始预热 if self.decodingSteps == 1 and hasattr(self,'dw'): # 开始预热
self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时 self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时
@@ -297,7 +318,7 @@ class Decoder_main(threading.Thread, device_type):
print('预测结果:' + str(choosenNum) + ',计算次数:' + str(self.calculateCount)) print('预测结果:' + str(choosenNum) + ',计算次数:' + str(self.calculateCount))
self.calculateCount = 0 self.calculateCount = 0
if self.decodingSteps == 3: # 发送解码后的信息 if self.decodingSteps == 3: # 发送解码后的信息
self.zmqClient.send_to_all('result', int(choosenNum)) self.zmqServer.broadcast_message('result', int(choosenNum))
self.decodingSteps = 0 self.decodingSteps = 0
print('发送给界面完成。') print('发送给界面完成。')
@@ -316,19 +337,19 @@ class Decoder_main(threading.Thread, device_type):
formatted_time = now.strftime('%H:%M:%S.%f')[:-3] formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
print('模型训练完成', formatted_time) print('模型训练完成', formatted_time)
self.load_model = True self.load_model = True
self.zmqClient.send_to_all('paradigm', 1) self.zmqServer.broadcast_message('paradigm', 1)
'''训练阶段采集数据''' '''训练阶段采集数据'''
if self.zmqServer.state_mode == 'train': # 训练状态 if self.zmqServer.state_mode == 'train': # 训练状态
if self.zmqServer.StartTrain: if self.zmqServer.StartTrain:
self.currentLabel = self.zmqServer.currentLabel self.currentLabel = self.zmqServer.currentLabel
self.zmqServer.StartTrain = False self.zmqServer.StartTrain = False
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \ if self.thread_data_server.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
self.train_epoch[1] \ self.train_epoch[1] \
+ self.thread_data_server.event_inner_idx: + self.thread_data_server.event_inner_idx:
time.sleep(0.0001) time.sleep(0.0001)
return return
print('训练队列数据:', self.thread_data_server.GetDataLenCount()) print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount())
trainTrial = self.thread_data_server.get_SSMVEPData() # 取出所有数据 trainTrial = self.thread_data_server.get_SSMVEPData() # 取出所有数据
print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, self.thread_data_server.event_inner_idx]) print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, self.thread_data_server.event_inner_idx])
trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理 trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理
@@ -352,7 +373,7 @@ class Decoder_main(threading.Thread, device_type):
formatted_time = now.strftime('%H:%M:%S.%f')[:-3] formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
print('启动预测 ', formatted_time) print('启动预测 ', formatted_time)
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \ if self.thread_data_server.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
self.interval_epoch[1] \ self.interval_epoch[1] \
+ self.thread_data_server.event_inner_idx: + self.thread_data_server.event_inner_idx:
time.sleep(0.0001) time.sleep(0.0001)
@@ -364,8 +385,8 @@ class Decoder_main(threading.Thread, device_type):
self.thread_data_server.event_inner_idx + self.interval_epoch[ self.thread_data_server.event_inner_idx + self.interval_epoch[
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]] 0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]]
pad_eeg_test = np.zeros( pad_eeg_test = np.zeros(
(data.shape[0], int((self.sample_length + 0.1) * self.fs))) (data.shape[0], int((self.sample_length + 0.1) * self.device_info['sample_rate'])))
pad_eeg_test[:, :int(self.sample_length * self.fs)] = data pad_eeg_test[:, :int(self.sample_length * self.device_info['sample_rate'])] = data
choosenNum, features_2 = self.decoder.predict(pad_eeg_test) choosenNum, features_2 = self.decoder.predict(pad_eeg_test)
if isinstance(choosenNum, np.ndarray): if isinstance(choosenNum, np.ndarray):
choosenNum = choosenNum[0] choosenNum = choosenNum[0]
@@ -375,7 +396,7 @@ class Decoder_main(threading.Thread, device_type):
print('发送给界面完成。') print('发送给界面完成。')
else: # 休息状态 else: # 休息状态
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态 if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
if self.thread_data_server.GetDataLenCount() < 25: if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
time.sleep(0.005) time.sleep(0.005)
return return
self.thread_data_server.getData(25) self.thread_data_server.getData(25)
@@ -423,12 +444,12 @@ class Decoder_main(threading.Thread, device_type):
if self.zmqServer.StartTrain: if self.zmqServer.StartTrain:
self.currentLabel = self.zmqServer.currentLabel self.currentLabel = self.zmqServer.currentLabel
self.zmqServer.StartTrain = False self.zmqServer.StartTrain = False
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \ if self.thread_data_server.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
self.interval_epoch[1] \ self.interval_epoch[1] \
+ self.thread_data_server.event_inner_idx: + self.thread_data_server.event_inner_idx:
time.sleep(0.0001) time.sleep(0.0001)
return return
print('训练队列数据:', self.thread_data_server.GetDataLenCount()) print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount())
originalTrial = self.thread_data_server.get_MIData() # 取出MI导联数据 originalTrial = self.thread_data_server.get_MIData() # 取出MI导联数据
print('取出的: ', originalTrial.shape, 'event: ', originalTrial[-2, self.thread_data_server.event_inner_idx]) print('取出的: ', originalTrial.shape, 'event: ', originalTrial[-2, self.thread_data_server.event_inner_idx])
trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理 trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理
@@ -452,7 +473,7 @@ class Decoder_main(threading.Thread, device_type):
formatted_time = now.strftime('%H:%M:%S.%f')[:-3] formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
print('启动预测 ', formatted_time) print('启动预测 ', formatted_time)
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \ if self.thread_data_server.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
self.interval_epoch[1] \ self.interval_epoch[1] \
+ self.thread_data_server.event_inner_idx: + self.thread_data_server.event_inner_idx:
time.sleep(0.0001) time.sleep(0.0001)
@@ -481,128 +502,128 @@ class Decoder_main(threading.Thread, device_type):
print(f'发送给界面完成,耗时{end - start:.3f}s。') print(f'发送给界面完成,耗时{end - start:.3f}s。')
else: # 休息状态 else: # 休息状态
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态 if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
if self.thread_data_server.GetDataLenCount() < 25: if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
time.sleep(0.005) time.sleep(0.005)
return return
self.thread_data_server.getData(25) self.thread_data_server.getData(25)
def decoder_concentration(self): # def decoder_concentration(self):
if self.zmqServer.state_mode == 'predict': # if self.zmqServer.state_mode == 'predict':
if self.zmqServer.StartDecode: # if self.zmqServer.StartDecode:
self.zmqServer.StartDecode = False # self.zmqServer.StartDecode = False
self.thread_data_server.ResetAll() # self.thread_data_server.ResetAll()
now = datetime.now() # now = datetime.now()
formatted_time = now.strftime('%H:%M:%S.%f')[:-3] # formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
print('启动专注力预测 ', formatted_time) # print('启动专注力预测 ', formatted_time)
if self.thread_data_server.GetDataLenCount() < int(self.win_step * self.fs): # 每win_step得出一次结果 # if self.thread_data_server.GetDataLenCount() < int(self.win_step * self.device_info['sample_rate']): # 每win_step得出一次结果
time.sleep(0.005) # time.sleep(0.005)
return # return
if self.zmqServer.get_Impedance != False: # 阻抗检测状态不解码 # if self.zmqServer.get_Impedance != False: # 阻抗检测状态不解码
return # return
data = self.thread_data_server.get_concentrateData(int(self.win_step * self.fs)) # 修改每次读取的数据 # data = self.thread_data_server.get_concentrateData(int(self.win_step * self.device_info['sample_rate'])) # 修改每次读取的数据
result = self.calculate.queueOpt(data) # result = self.calculate.queueOpt(data)
if result is not None: # if result is not None:
self.zmqClient.send_to_all('result', int(result)) # self.zmqClient.send_to_all('result', int(result))
else: # 休息状态 # else: # 休息状态
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态 # if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
if self.thread_data_server.GetDataLenCount() < 25: # if self.thread_data_server.GetDataLenCount() < 25:
time.sleep(0.005) # time.sleep(0.005)
return # return
self.thread_data_server.getData(25) # self.thread_data_server.getData(25)
#### Blink detection ##### #### Blink detection #####
def check_double_blink(self, current_time): # def check_double_blink(self, current_time):
""" # """
检查是否检测到连续两次眨眼 # 检查是否检测到连续两次眨眼
@param current_time: 当前眨眼时间戳 # @param current_time: 当前眨眼时间戳
@return: True表示检测到连续两次眨眼 # @return: True表示检测到连续两次眨眼
""" # """
if len(self.blink_timestamps) < 2: # if len(self.blink_timestamps) < 2:
return False # return False
# 检查是否在去抖期内 # # 检查是否在去抖期内
if self.last_double_blink_time > 0: # if self.last_double_blink_time > 0:
time_since_last_double_blink = current_time - self.last_double_blink_time # time_since_last_double_blink = current_time - self.last_double_blink_time
if time_since_last_double_blink < self.double_blink_jitter: # if time_since_last_double_blink < self.double_blink_jitter:
return False # 在去抖期内,忽略连续眨眼检测 # return False # 在去抖期内,忽略连续眨眼检测
last_time = self.blink_timestamps[-1] # 当前眨眼 # last_time = self.blink_timestamps[-1] # 当前眨眼
prev_time = self.blink_timestamps[-2] # 上次眨眼 # prev_time = self.blink_timestamps[-2] # 上次眨眼
interval = last_time - prev_time # interval = last_time - prev_time
if interval <= self.double_blink_interval: # if interval <= self.double_blink_interval:
return True # return True
return False # return False
def process_blink_detection(self): # def process_blink_detection(self):
""" # """
在缓冲区数据上执行,单次眨眼检测 # 在缓冲区数据上执行,单次眨眼检测
""" # """
if len(self.fp1_buffer) < self.window_samples: # if len(self.fp1_buffer) < self.window_samples:
return # return
fp1_data = np.array(list(self.fp1_buffer)[-self.window_samples:]) # fp1_data = np.array(list(self.fp1_buffer)[-self.window_samples:])
fp2_data = np.array(list(self.fp2_buffer)[-self.window_samples:]) # fp2_data = np.array(list(self.fp2_buffer)[-self.window_samples:])
# 计算FP1和FP2的平均 # # 计算FP1和FP2的平均
fp12_mean = (fp1_data + fp2_data) / 2.0 # fp12_mean = (fp1_data + fp2_data) / 2.0
# 带通滤波 # # 带通滤波
try: # try:
fp12_filtered = signal.filtfilt(self.blink_b, self.blink_a, fp12_mean) # fp12_filtered = signal.filtfilt(self.blink_b, self.blink_a, fp12_mean)
except Exception as e: # except Exception as e:
print(f"Filter error: {e}") # print(f"Filter error: {e}")
return # return
F = np.diff(fp12_filtered) # F = np.diff(fp12_filtered)
if len(F) < 3: # if len(F) < 3:
return # return
b, d, e = blink_detection(F, self.fs, self.Dmin, self.Dmax, self.EMin, self.EMax) # b, d, e = blink_detection(F, self.device_info['sample_rate'], self.Dmin, self.Dmax, self.EMin, self.EMax)
if b == 1: # if b == 1:
samples_since_last = self.total_samples - self.last_blink_time # samples_since_last = self.total_samples - self.last_blink_time
time_since_last_ms = (samples_since_last / self.fs) * 1000 # time_since_last_ms = (samples_since_last / self.device_info['sample_rate']) * 1000
if time_since_last_ms >= self.jitterwin: # self.jitterwin 单次眨眼去抖 using time_since_last_ms # if time_since_last_ms >= self.jitterwin: # self.jitterwin 单次眨眼去抖 using time_since_last_ms
self.blink_count += 1 # self.blink_count += 1
self.last_blink_time = self.total_samples # self.last_blink_time = self.total_samples
current_time = time.time() # current_time = time.time()
self.blink_timestamps.append(current_time) # self.blink_timestamps.append(current_time)
blink_event = { # blink_event = {
'count': self.blink_count, # 'count': self.blink_count,
'time': current_time, # 'time': current_time,
'sample_index': self.total_samples, # 'sample_index': self.total_samples,
'duration_ms': d, # 'duration_ms': d,
'energy': e # 'energy': e
} # }
self.blink_events.append(blink_event) # self.blink_events.append(blink_event)
self.zmqClient.send_to_all('result', 1) # 检测到眨眼信号,通知上位机 # self.zmqClient.send_to_all('result', 1) # 检测到眨眼信号,通知上位机
if self.check_double_blink(current_time): # if self.check_double_blink(current_time):
self.double_blink_count += 1 # self.double_blink_count += 1
interval = self.blink_timestamps[-1] - self.blink_timestamps[-2] # interval = self.blink_timestamps[-1] - self.blink_timestamps[-2]
double_blink_event = { # double_blink_event = {
'double_blink_count': self.double_blink_count, # 'double_blink_count': self.double_blink_count,
'blink1_time': self.blink_timestamps[-2], # 'blink1_time': self.blink_timestamps[-2],
'blink2_time': self.blink_timestamps[-1], # 'blink2_time': self.blink_timestamps[-1],
'interval': interval # 'interval': interval
} # }
self.double_blink_events.append(double_blink_event) # self.double_blink_events.append(double_blink_event)
self.last_double_blink_time = current_time # self.last_double_blink_time = current_time
self.zmqClient.send_to_all('result', 2) # 发送双次眨眼事件 # self.zmqClient.send_to_all('result', 2) # 发送双次眨眼事件
def decoder_blink(self): # def decoder_blink(self):
if self.thread_data_server.GetDataLenCount() < 50: # if self.thread_data_server.GetDataLenCount() < 50:
time.sleep(0.005) # time.sleep(0.005)
return # return
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态 # if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
data = self.thread_data_server.get_blinkData(50) # data = self.thread_data_server.get_blinkData(50)
fp1_data = data[0, :] # ch1 (相当于FP1) # fp1_data = data[0, :] # ch1 (相当于FP1)
fp2_data = data[1, :] # ch2 (相当于FP2) # fp2_data = data[1, :] # ch2 (相当于FP2)
for i in range(len(fp1_data)): # for i in range(len(fp1_data)):
self.fp1_buffer.append(fp1_data[i]) # self.fp1_buffer.append(fp1_data[i])
self.fp2_buffer.append(fp2_data[i]) # self.fp2_buffer.append(fp2_data[i])
self.total_samples += 1 # self.total_samples += 1
self.sample_counter += 1 # self.sample_counter += 1
if self.sample_counter >= self.step_samples: # if self.sample_counter >= self.step_samples:
self.process_blink_detection() # self.process_blink_detection()
self.sample_counter = 0 # self.sample_counter = 0
def stop(self): def stop(self):
''' '''

View File

@@ -63,8 +63,53 @@ class ParadigmRingBuffer:
获取最新缓存中每个通道的数量 获取最新缓存中每个通道的数量
@return: @return:
''' '''
return self.nUpdate return self.nUpdate
# ========== 各范式数据访问接口 ==========
def get_MIData(self):
"""获取MI导联数据 (21通道 + 事件)"""
data = self.getData(self.GetDataLenCount())
rows_to_extract = [8, 15, 12, 14, 18, 23, 16, 59, 50, 58, 17, 45, 29, 11, 10, 19, 20, 61, 51, 60, 21, 64, 65]
row_to_select = np.array(rows_to_extract)
if data.shape[1] > 0:
return data[row_to_select, :]
return np.zeros((len(rows_to_extract), 0))
def get_SSMVEPData(self):
"""获取SSMVEP导联数据 (8通道 + 事件)"""
data = self.getData(self.GetDataLenCount())
rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64, 65]
row_to_select = np.array(rows_to_extract)
if data.shape[1] > 0:
return data[row_to_select, :]
return np.zeros((len(rows_to_extract), 0))
def getDataViaSSVEP(self, count):
"""获取SSVEP数据 (8通道 + 事件)"""
data = self.getData(count)
rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64]
row_to_select = np.array(rows_to_extract)
if data.shape[1] > 0:
return data[row_to_select, :]
return np.zeros((len(rows_to_extract), 0))
def get_concentrateData(self, count):
"""获取专注力数据 (2通道)"""
data = self.getData(count)
rows_to_extract = [0, 1]
row_to_select = np.array(rows_to_extract)
if data.shape[1] > 0:
return data[row_to_select, :]
return np.zeros((len(rows_to_extract), 0))
def get_blinkData(self, count):
"""获取眨眼数据 (2通道)"""
data = self.getData(count)
rows_to_extract = [0, 1]
row_to_select = np.array(rows_to_extract)
if data.shape[1] > 0:
return data[row_to_select, :]
return np.zeros((len(rows_to_extract), 0))
# reset buffer # reset buffer
def resetAllPara(self): def resetAllPara(self):
@@ -72,6 +117,4 @@ class ParadigmRingBuffer:
self.currentPtr = 0 self.currentPtr = 0
self.readPtr = 0 # add by lizhenhua 清空读指针 self.readPtr = 0 # add by lizhenhua 清空读指针
self.buffer = np.zeros((self.n_chan, self.n_points)) # add by lizhenhua 清空环形缓冲区 self.buffer = np.zeros((self.n_chan, self.n_points)) # add by lizhenhua 清空环形缓冲区

View File

@@ -1,16 +1,22 @@
import ast
import numpy as np import numpy as np
import zmq
import threading import threading
import json import json
import queue import queue
from typing import Dict
# from Device.SunnyLinker import SunnyLinker64 # from Device.SunnyLinker import SunnyLinker64
from dataBuffer import ParadigmRingBuffer from Zmq.dataBuffer import ParadigmRingBuffer
from filterProcess import FilterRingBuffer from Zmq.filterProcess import FilterRingBuffer
from PubLibrary.InifileHelper import IniRead
from logs.log import algo_log from logs.log import algo_log
import zmq
class zmqServer(threading.Thread): class zmqServer(threading.Thread):
def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100, device_info=None): def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100, device_info=None):
threading.Thread.__init__(self) threading.Thread.__init__(self)
self.device_info = device_info
self.host = host self.host = host
self.cmd_port = cmd_port # 命令交互端口 self.cmd_port = cmd_port # 命令交互端口
self.data_port = data_port # 数据接收端口 self.data_port = data_port # 数据接收端口
@@ -28,30 +34,29 @@ class zmqServer(threading.Thread):
self.daemon = True self.daemon = True
# 范式数据缓存 # 范式数据缓存
self.paradigmBuffer = ParadigmRingBuffer(66, 2500) self.paradigmBuffer = ParadigmRingBuffer(self.device_info['channel_nums'], self.device_info['sample_rate'] * 10)
self.filterBuffer = FilterRingBuffer(66, 2500) self.filterBuffer = FilterRingBuffer(self.device_info['channel_nums'], self.device_info['sample_rate'] * 10)
# 命令与数据通信 # 命令与数据通信
self.context = zmq.Context() self.context = zmq.Context()
# 指令通道 (8099) - ROUTER短JSON命令低频率 # 指令通道 (8099) - ROUTER短JSON命令低频率
self.cmd_socket = self.context.socket(zmq.ROUTER) self.cmd_socket = self.context.socket(zmq.ROUTER)
self.cmd_socket.setsockopt(zmq.RCVHWM, 100) # 指令不需要大缓存100条足够 # 通用套接字选项:仍在 SocketOption 中
self.cmd_socket.setsockopt(zmq.SNDHWM, 100) self.cmd_socket.setsockopt(zmq.SocketOption.RCVHWM, 100)
self.cmd_socket.setsockopt(zmq.TCP_NODELAY, 1) # 禁用Nagle算法降低指令延迟 self.cmd_socket.setsockopt(zmq.SocketOption.SNDHWM, 100)
self.cmd_socket.bind(f"tcp://{self.host}:{cmd_port}") self.cmd_socket.bind(f"tcp://{self.host}:{cmd_port}")
# 数据通道 (8100) - ROUTER高频脑电二进制流 # 数据通道 (8100) - ROUTER高频脑电二进制流
self.data_socket = self.context.socket(zmq.ROUTER) self.data_socket = self.context.socket(zmq.ROUTER)
self.data_socket.setsockopt(zmq.RCVHWM, 500) # 500包=10秒缓存足够应对短时卡顿 self.data_socket.setsockopt(zmq.SocketOption.RCVHWM, 500)
self.data_socket.setsockopt(zmq.TCP_NODELAY, 1) # 禁用Nagle算法减少数据传输延迟
self.data_socket.bind(f"tcp://{self.host}:{data_port}") self.data_socket.bind(f"tcp://{self.host}:{data_port}")
# Poller 轮训器(保持不变) # Poller 轮训器(保持不变)
self.poller = zmq.Poller() self.poller = zmq.Poller()
self.poller.register(self.cmd_socket, zmq.POLLIN) self.poller.register(self.cmd_socket, zmq.POLLIN)
self.poller.register(self.data_socket, zmq.POLLIN) self.poller.register(self.data_socket, zmq.POLLIN)
# 业务变量 # 业务变量
self.targetFreqs = [] self.targetFreqs = []
self.changeTarget = False # 更换目标频率 self.changeTarget = False # 更换目标频率
@@ -64,6 +69,77 @@ class zmqServer(threading.Thread):
self.cmd_clients = set() # 命令端口客户端ID self.cmd_clients = set() # 命令端口客户端ID
self.data_clients = set() # 数据端口客户端ID self.data_clients = set() # 数据端口客户端ID
self.send_queue = queue.Queue() # 发送队列(仅用于命令端口广播) self.send_queue = queue.Queue() # 发送队列(仅用于命令端口广播)
# 范式buffer参数, 事件检测相关
self._event_lock = threading.Lock()
self._epoch_finished = False
self._event_inner_idx = -1
self.pack_contain_event = False
self.predict_event = 99
self.events = [1, 2, self.predict_event]
self.count_events = {}
self.latency = 50
self.train_latency = 50
self._interval_inited = False
@property
def interval_inited(self):
return self._interval_inited
@interval_inited.setter
def interval_inited(self, value):
self._interval_inited = value
@property
def epoch_finished(self):
with self._event_lock:
return self._epoch_finished
@epoch_finished.setter
def epoch_finished(self, value):
with self._event_lock:
self._epoch_finished = value
@property
def event_inner_idx(self):
with self._event_lock:
return self._event_inner_idx
@event_inner_idx.setter
def event_inner_idx(self, value):
with self._event_lock:
self._event_inner_idx = value
def interval_init(self, decoder_class):
if decoder_class == 'ssmvep':
interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] # epoch截取信息
self.train_epoch = [int(self.interval_epoch[0]),
int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])] # 训练样本epoch
self.latency = (self.interval_epoch[
1] + 0.1 * self.device_info['sample_rate']) // 5 # 提取epoch的延迟标记5代表每次解包得到的5位采样点;0.1表示比实际需要的长度多取0.1,会被截掉
self.train_latency = (self.train_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5
elif decoder_class == 'mi':
interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] # epoch截取信息
self.train_epoch = self.interval_epoch.copy()
self.latency = (self.interval_epoch[1]) // 5 # 提取epoch的延迟标记5代表每次解包得到的5位采样点;
self.train_latency = self.latency
print('时间窗:', (interval_epoch))
self.count_events: Dict[str, int] = {} # 表示包延迟的计数信息
self.event_inner_idx = -1 # event在5位数据包内部的idx
self.epoch_finished = False # 接收epoch是否完整
self.pack_contain_event = False # 当前包是否含有event
self.predict_event = 99
self.events = [1, 2, self.predict_event]
self.interval_inited = True
# if getattr(self, 'serial', None) and self.serial.is_open:
# self.serial.close()
# self.serial = serial.Serial(self.serial_port, 460800, timeout=1) # 连接同步器串口
def broadcast_message(self, method, params): def broadcast_message(self, method, params):
"""Put message into queue to be sent to all command clients""" """Put message into queue to be sent to all command clients"""
@@ -78,15 +154,15 @@ class zmqServer(threading.Thread):
# 注册新的命令客户端 # 注册新的命令客户端
if ident not in self.cmd_clients: if ident not in self.cmd_clients:
self.cmd_clients.add(ident) self.cmd_clients.add(ident)
print(f"New CMD Client Connected: {ident} (port: {self.cmd_port})") algo_log(f"New CMD Client Connected: {ident} (port: {self.cmd_port})")
# 解析消息 # 解析消息
try: try:
message = json.loads(message_bytes.decode('utf-8')) message = json.loads(message_bytes.decode('utf-8'))
except json.JSONDecodeError: except json.JSONDecodeError:
print(f"Invalid JSON from CMD client {ident}") algo_log(f"Invalid JSON from CMD client {ident}")
continue return
print(f"Received CMD request: {message}") algo_log(f"Received CMD request: {message}")
method = message.get("method") method = message.get("method")
params = message.get("params") params = message.get("params")
@@ -94,37 +170,40 @@ class zmqServer(threading.Thread):
# 原有命令处理逻辑 # 原有命令处理逻辑
if method == "sync": if method == "sync":
self.state_mode = 'sync' self.state_mode = 'sync'
if method == "targetFreqs": elif method == "targetFreqs":
if not isinstance(params, list): if not isinstance(params, list):
print('targetFreqs must be a list') algo_log(f"targetFreqs must be a list")
continue return
if params != self.targetFreqs: if params != self.targetFreqs:
self.targetFreqs = params self.targetFreqs = params
self.changeTarget = True self.changeTarget = True
if method == "decoderClass": elif method == "decoderClass":
if not isinstance(params, str): if not isinstance(params, str):
print('decoderClass must be a str') algo_log(f"decoderClass must be a str")
continue return
if params != self.decoder_class: if params != self.decoder_class:
self.decoder_class = params self.decoder_class = params
self.decoder_switch = True self.decoder_switch = True
if method == "getReport": elif method == "train":#训练状态
self.getReport = True
if method == "train":#训练状态
self.state_mode = 'train' self.state_mode = 'train'
self.StartTrain = True self.StartTrain = True
self.currentLabel = params # 当前刺激端的训练标签 self.currentLabel = params # 当前刺激端的训练标签
self.sunnyLinker.push_trigger(self.labels[self.currentLabel]) # self.sunnyLinker.push_trigger(self.labels[self.currentLabel])
elif method == "predict":#预测状态 elif method == "predict":#预测状态
self.state_mode = 'predict' self.state_mode = 'predict'
if params == 1: #开始解码 if params == 1: #开始解码
self.StartDecode = True self.StartDecode = True
self.sunnyLinker.push_trigger(0x63) # self.sunnyLinker.push_trigger(0x63)
elif params == 2: #停止解码 elif params == 2: #停止解码
self.IsExitApp = True self.IsExitApp = True
self.running = False self.running = False
elif method == "rest": #休息状态 elif method == "rest": #休息状态
self.state_mode = 'rest' self.state_mode = 'rest'
else:
algo_log(f"未知命令:{method}", level="WARNING")
# elif method == "getReport":
# self.getReport = True
# elif method == "impedance": # elif method == "impedance":
# if params == 1: # if params == 1:
# self.open_Impedance = True # 开启阻抗 # self.open_Impedance = True # 开启阻抗
@@ -138,47 +217,48 @@ class zmqServer(threading.Thread):
处理8100端口原始脑电二进制数据 处理8100端口原始脑电二进制数据
固定格式:上位机发送 (5,66) float32 二维数组字节流(已转换为微伏物理量)→ 转置为 (66,5) 写入双缓冲区 固定格式:上位机发送 (5,66) float32 二维数组字节流(已转换为微伏物理量)→ 转置为 (66,5) 写入双缓冲区
""" """
# 1. 校验ZMQ消息帧完整性 # 1. 校验ZMQ消息帧完整性ROUTER接收DEALER消息的帧格式[客户端ID, 发送方ID, 空帧, 数据帧]
if len(frames) < 3: if len(frames) < 4: # 至少需要4帧
print(f"[ERROR] 无效数据帧长度不足3帧实际长度={len(frames)}") algo_log(f"Invalid data frame: 帧数量不足期望≥4实际{len(frames)}", level="ERROR")
return return
ident, _, data_bytes = frames[:3] # 2. 正确解析帧适配DEALER→ROUTER的帧格式
client_ident, sender_ident, empty_sep, data_bytes = frames[:4]
if empty_sep != b'': # 校验空分隔帧
algo_log(f"Invalid frame separator: 期望空字节,实际{empty_sep}", level="ERROR")
return
# 2. 客户端管理(单客户端场景,自动更新最新身份) # 3. 客户端管理(单客户端场景,自动更新最新身份)
if ident not in self.data_clients: if client_ident not in self.data_clients:
self.data_clients.add(ident) self.data_clients.add(client_ident)
self.current_data_client = ident # 保存唯一客户端身份,用于后续回复滤波结果 self.current_data_client = client_ident # 保存唯一客户端身份,用于后续回复滤波结果
print(f"[INFO] 新数据客户端连接成功:{ident}") print(f"[INFO] 新数据客户端连接成功:{client_ident}")
try: try:
# 3. 精确长度校验(核心:固定(5,66) float32 = 5*66*4=1320字节与int32字节数相同 # 4. 精确长度校验(核心:固定(5,66) float32 = 5*66*4=1320字节
EXPECTED_BYTES = 5 * 66 * 4 # 每个float32占4字节 EXPECTED_BYTES = self.device_info['frame_points'] * self.device_info['channel_nums'] * 4 # 每个float32占4字节
if len(data_bytes) != EXPECTED_BYTES: if len(data_bytes) != EXPECTED_BYTES:
print(f"[ERROR] 数据长度错误:期望{EXPECTED_BYTES}字节,实际{len(data_bytes)}字节") algo_log(f"[ERROR] 数据长度错误:期望{EXPECTED_BYTES}字节,实际{len(data_bytes)}字节", level="ERROR")
return return
# 4. 零拷贝二进制解析 + 维度转换 # 5. 零拷贝二进制解析 + 维度转换
# 步骤:字节流 → (330,) float32数组 → (5,66) 原始格式 → 转置为 (66,5) 缓冲区标准格式
data_np = np.frombuffer(data_bytes, dtype=np.float32) data_np = np.frombuffer(data_bytes, dtype=np.float32)
# 重塑为上位机原始维度 data_np = data_np.reshape(self.device_info['frame_points'], self.device_info['channel_nums'])
data_np = data_np.reshape(5, 66)
# 转置为(通道数, 采样点数)标准格式转换为float64保证滤波运算精度
data_np = data_np.T.astype(np.float64) data_np = data_np.T.astype(np.float64)
# 5. 同时写入双环形缓冲区方法名与现有类保持一致appendBuffer # 6. 写入缓冲区
# 注意:上位机已发送微伏物理量,无需再乘以增益系数
self.paradigmBuffer.appendBuffer(data_np) self.paradigmBuffer.appendBuffer(data_np)
self.filterBuffer.appendBuffer(data_np) self.filterBuffer.appendBuffer(data_np)
# 生产环境必须注释每秒50次打印会导致CPU占用飙升30%以上 algo_log(f"数据写入成功shape={data_np.shape}, 范围=[{data_np.min():.2f}, {data_np.max():.2f}] μV", level="DEBUG")
algo_log(f"数据写入成功shape={data_np.shape}, 范围=[{data_np.min():.2f}, {data_np.max():.2f}] μV", level="DEBUG", record_once=True)
except Exception as e: except Exception as e:
algo_log(f"数据处理失败:{str(e)}", level="ERROR") algo_log(f"数据处理失败:{str(e)}", level="ERROR")
# 调试阶段临时打开,生产环境务必注释 if IniRead('system', 'algo_log_level', 'INFO') == 'DEBUG':
import traceback import traceback
traceback.print_exc() traceback.print_exc()
def _process_send_queue(self): def _process_send_queue(self):
"""处理发送队列,向所有命令客户端广播消息""" """处理发送队列,向所有命令客户端广播消息"""
@@ -207,15 +287,15 @@ class zmqServer(threading.Thread):
def run(self): def run(self):
self.running = True self.running = True
print(f"ZMQ Server started - CMD Port: {self.cmd_port}, DATA Port: {self.data_port}") algo_log(f"algo ZMQ Server started - CMD Port: {self.cmd_port}, DATA Port: {self.data_port}", level="INFO")
try: try:
while self.running: while self.running:
# 1. 处理发送队列(命令端口广播) # 1. 处理发送队列(命令端口广播)
self._process_send_queue() self._process_send_queue()
# 2. 轮训监听两个Socket的输入事件10ms超时避免阻塞 # 2. 轮训监听两个Socket的输入事件
socks = dict(self.poller.poll(10)) socks = dict(self.poller.poll(50))
# 处理命令端口消息 # 处理命令端口消息
if self.cmd_socket in socks and socks[self.cmd_socket] == zmq.POLLIN: if self.cmd_socket in socks and socks[self.cmd_socket] == zmq.POLLIN:

View File

@@ -1,445 +0,0 @@
import numpy as np
import zmq
import threading
import json
import queue
import time
from Device.SunnyLinker import SunnyLinker64, RingBuffer
from collections import deque
class zmqServer(threading.Thread):
def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100):
threading.Thread.__init__(self)
self.host = host
self.cmd_port = cmd_port
self.data_port = data_port
self.running = False
self.get_Impedance = False
self.open_Impedance = None
self.StartDecode = False
self.StartTrain = False
self.state_mode = None
self.currentLabel = -1
self.IsExitApp = False
self.getReport = False
self.daemon = True
# ZMQ Context
self.context = zmq.Context()
# 指令通道 (8099) - ROUTER
self.cmd_socket = self.context.socket(zmq.ROUTER)
self.cmd_socket.setsockopt(zmq.RCVHWM, 1000)
self.cmd_socket.setsockopt(zmq.SNDHWM, 1000)
self.cmd_socket.bind(f"tcp://{self.host}:{cmd_port}")
# 数据通道 (8100)) - ROUTER
self.data_socket = self.context.socket(zmq.ROUTER)
self.data_socket.setsockopt(zmq.RCVHWM, 1000)
self.data_socket.setsockopt(zmq.RCVTIMEO, 50)
self.data_socket.bind(f"tcp://{self.host}:{data_port}")
self.targetFreqs = []
self.changeTarget = False
self.sunnyLinker = SunnyLinker64(None, None, None, None, None)
self.labels = [0x01, 0x02, 0x03]
self.decoder_switch = False
self.decoder_class = None
self.cmd_clients = set()
self.data_clients = set()
self.send_queue = queue.Queue()
# ========== 数据缓冲区 (RingBuffer) ==========
# 与 SunnyLinker 保持一致,使用 RingBuffer
# 66 = 64 EEG通道 + 1 事件通道(第65) + 1 标签序号通道(第66)
# 缓存约 10 秒数据 (250Hz * 10s = 2500 点)
self.n_chan = 66
self.t_buffer = 10.0 # 缓冲区时长(秒)
self.__ringBuffer = RingBuffer(self.n_chan, int(self.t_buffer * 250))
# 事件检测相关
self._event_lock = threading.Lock()
self._epoch_finished = False
self._event_inner_idx = -1
self.pack_contain_event = False
self.predict_event = 99
self.events = [1, 2, self.predict_event]
self.count_events = {}
self.latency = 50
self.train_latency = 50
# 当前事件标签序号 (从第66通道获取)
self.current_label_index = 0
# 初始化标志
self._interval_inited = False
self._currentLabel = -1
# 注册的客户端(兼容旧接口)
self.clients = set()
# ========== 事件属性:线程安全访问 ==========
@property
def epoch_finished(self):
with self._event_lock:
return self._epoch_finished
@epoch_finished.setter
def epoch_finished(self, value):
with self._event_lock:
self._epoch_finished = value
@property
def event_inner_idx(self):
with self._event_lock:
return self._event_inner_idx
@event_inner_idx.setter
def event_inner_idx(self, value):
with self._event_lock:
self._event_inner_idx = value
@property
def interval_inited(self):
return self._interval_inited
@interval_inited.setter
def interval_inited(self, value):
self._interval_inited = value
@property
def currentLabel(self):
return self._currentLabel
@currentLabel.setter
def currentLabel(self, value):
self._currentLabel = value
def broadcast_message(self, method, params):
"""Put message into queue to be sent to all connected clients"""
self.send_queue.put((method, params))
# ========== 数据缓冲区操作接口 ==========
def GetDataLenCount(self):
"""返回缓冲区当前数据点数"""
return self.__ringBuffer.nUpdate
def getData(self, count):
"""获取最新count个数据点不消费只读"""
with self.__ringBuffer.RingBufferLock:
count = min(count, self.__ringBuffer.nUpdate)
if count == 0:
return np.zeros((self.n_chan, 0))
# 计算读取范围(从尾部取最新数据)
read_end = (self.__ringBuffer.currentPtr - 1) % self.__ringBuffer.n_points
read_start = (read_end - count + 1) % self.__ringBuffer.n_points
if self.__ringBuffer.currentPtr == 0:
read_start = self.__ringBuffer.n_points - count
read_end = self.__ringBuffer.n_points - 1
if read_start <= read_end:
data = self.__ringBuffer.buffer[:, read_start:read_end + 1]
else:
part1 = self.__ringBuffer.buffer[:, read_start:]
part2 = self.__ringBuffer.buffer[:, :read_end + 1]
data = np.concatenate((part1, part2), axis=1)
return data
def consumeData(self, count):
"""消费(丢弃)指定数量的数据点,从头部移除"""
with self.__ringBuffer.RingBufferLock:
count = min(count, self.__ringBuffer.nUpdate)
self.__ringBuffer.readPtr = (self.__ringBuffer.readPtr + count) % self.__ringBuffer.n_points
self.__ringBuffer.nUpdate -= count
def ResetAll(self):
"""重置缓冲区"""
with self.__ringBuffer.RingBufferLock:
self.__ringBuffer.resetAllPara()
with self._event_lock:
self._epoch_finished = False
self._event_inner_idx = -1
self.pack_contain_event = False
self.count_events.clear()
self.current_label_index = 0
def reset_data_buffer(self):
self.ResetAll()
def reset_state(self):
self.ResetAll()
def interval_init(self, decoder_class):
"""初始化事件检测参数"""
import ast
from PubLibrary.InifileHelper import IniRead
if decoder_class == 'ssmvep':
interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
self.interval_epoch = [int(i * 250) for i in interval_epoch]
self.train_epoch = [int(self.interval_epoch[0]),
int(self.interval_epoch[1] + 0.1 * 250)]
self.latency = (self.interval_epoch[1] + 0.1 * 250) // 5
self.train_latency = (self.train_epoch[1] + 0.1 * 250) // 5
elif decoder_class == 'mi':
interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
self.interval_epoch = [int(i * 250) for i in interval_epoch]
self.train_epoch = self.interval_epoch.copy()
self.latency = self.interval_epoch[1] // 5
self.train_latency = self.latency
self.count_events = {}
self._event_inner_idx = -1
self._epoch_finished = False
self.pack_contain_event = False
self.predict_event = 99
self.events = [1, 2, self.predict_event]
self._interval_inited = True
# ========== 事件检测 ==========
def detect_event(self, data_matrix):
"""
检测事件通道中的触发信号
@param data_matrix: shape (66, N) - N个采样点的数据
第65行(索引64) = 事件通道
第66行(索引65) = 标签通道
@return: 是否检测到事件
"""
if data_matrix.shape[1] == 0:
return False
self.pack_contain_event = False
event_channel = data_matrix[64, :] # 第65通道 = 标签值(event值)
label_channel = data_matrix[65, :] # 第66通道 = 标签序号(label index)
events = event_channel.tolist()
with self._event_lock:
self._event_inner_idx = -1
self.current_event_label = 0
for idx, event in enumerate(events):
if int(event) in self.events:
self._event_inner_idx = idx
self.current_label_index = int(label_channel[idx])
self.pack_contain_event = True
new_key = f"{event}_{time.time()}"
latency = self.latency if event == self.predict_event else self.train_latency
self.count_events[new_key] = latency + 1
# 延迟计数递减
drop_items = []
for key, value in self.count_events.items():
value = value - 1
if value == 0:
drop_items.append(key)
self.count_events[key] = value
for key in drop_items:
del self.count_events[key]
if drop_items:
self._epoch_finished = True
# 检测到事件时清除RingBuffer中之前的数据只保留当前包
if self.pack_contain_event:
self.__ringBuffer.resetAllPara()
return True
self._epoch_finished = False
return False
def run(self):
self.running = True
print(f"Server running - CMD: {self.cmd_port}, DATA: {self.data_port}")
cmd_poller = zmq.Poller()
cmd_poller.register(self.cmd_socket, zmq.POLLIN)
data_poller = zmq.Poller()
data_poller.register(self.data_socket, zmq.POLLIN)
try:
while self.running:
# --- 处理发送队列 (指令通道) ---
while not self.send_queue.empty():
method, params = self.send_queue.get()
if self.cmd_clients:
try:
msg = {'method': method, 'params': params}
msg_bytes = json.dumps(msg).encode('utf-8')
for client_id in list(self.cmd_clients):
try:
self.cmd_socket.send_multipart([client_id, b'', msg_bytes])
except Exception:
pass
except Exception:
pass
# --- 处理指令通道 ---
socks = dict(cmd_poller.poll(10))
if self.cmd_socket in socks:
self._handle_cmd_socket()
# --- 处理数据通道 ---
socks = dict(data_poller.poll(10))
if self.data_socket in socks:
self._handle_data_socket()
except Exception as e:
print(f"Server error: {e}")
finally:
self.running = False
self.cmd_socket.close()
self.data_socket.close()
self.context.term()
def _handle_cmd_socket(self):
"""处理指令通道消息"""
try:
frames = self.cmd_socket.recv_multipart()
if len(frames) < 3:
return
ident, _, message_bytes = frames[:3]
self.cmd_clients.add(ident)
self.clients.add(ident)
message = json.loads(message_bytes.decode('utf-8'))
method = message.get("method")
params = message.get("params")
print(f"[CMD] {method}: {params}")
if method == "sync":
self.state_mode = 'sync'
elif method == "targetFreqs":
if isinstance(params, list) and params != self.targetFreqs:
self.targetFreqs = params
self.changeTarget = True
elif method == "decoderClass":
if isinstance(params, str) and params != self.decoder_class:
self.decoder_class = params
self.decoder_switch = True
elif method == "getReport":
self.getReport = True
elif method == "train":
self.state_mode = 'train'
self.StartTrain = True
self.currentLabel = params
elif method == "predict":
self.state_mode = 'predict'
if params == 1:
self.StartDecode = True
elif params == 2:
self.IsExitApp = True
self.running = False
elif method == "rest":
self.state_mode = 'rest'
elif method == "impedance":
if params == 1:
self.open_Impedance = True
self.get_Impedance = True
elif params == 2:
self.open_Impedance = False
self.get_Impedance = False
except Exception as e:
print(f"CMD socket error: {e}")
def _handle_data_socket(self):
"""处理数据通道消息 (EEG数据)
上位机数据格式:
- 数据帧: [identity, '', meta_json, data_buffer]
data_buffer = [N, 66] float32 -> 转置为 [66, N]
"""
try:
frames = self.data_socket.recv_multipart()
if len(frames) < 4:
return
ident, _, message_bytes = frames[:3]
self.data_clients.add(ident)
meta = json.loads(message_bytes.decode('utf-8'))
# data: [N, 66] -> 转置 -> [66, N]
raw_data = np.frombuffer(frames[3], dtype=np.float32)
n_samples, n_channels = meta.get('shape', [5, 66])
data_matrix = raw_data.reshape(n_samples, n_channels).T.astype(np.float32)
# 写入 RingBuffer
with self.__ringBuffer.RingBufferLock:
self.__ringBuffer.appendBuffer(data_matrix)
# 事件检测
self.detect_event(data_matrix)
except Exception as e:
print(f"DATA socket error: {e}")
# ========== 各范式数据访问接口 ==========
def get_MIData(self):
"""获取MI导联数据 (21通道 + 事件)"""
data = self.getData(self.GetDataLenCount())
rows_to_extract = [8, 15, 12, 14, 18, 23, 16, 59, 50, 58, 17, 45, 29, 11, 10, 19, 20, 61, 51, 60, 21, 64, 65]
row_to_select = np.array(rows_to_extract)
if data.shape[1] > 0:
return data[row_to_select, :]
return np.zeros((len(rows_to_extract), 0))
def get_SSMVEPData(self):
"""获取SSMVEP导联数据 (8通道 + 事件)"""
data = self.getData(self.GetDataLenCount())
rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64, 65]
row_to_select = np.array(rows_to_extract)
if data.shape[1] > 0:
return data[row_to_select, :]
return np.zeros((len(rows_to_extract), 0))
def getDataViaSSVEP(self, count):
"""获取SSVEP数据 (8通道 + 事件)"""
data = self.getData(count)
rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64]
row_to_select = np.array(rows_to_extract)
if data.shape[1] > 0:
return data[row_to_select, :]
return np.zeros((len(rows_to_extract), 0))
def get_concentrateData(self, count):
"""获取专注力数据 (2通道)"""
data = self.getData(count)
rows_to_extract = [0, 1]
row_to_select = np.array(rows_to_extract)
if data.shape[1] > 0:
return data[row_to_select, :]
return np.zeros((len(rows_to_extract), 0))
def get_blinkData(self, count):
"""获取眨眼数据 (2通道)"""
data = self.getData(count)
rows_to_extract = [0, 1]
row_to_select = np.array(rows_to_extract)
if data.shape[1] > 0:
return data[row_to_select, :]
return np.zeros((len(rows_to_extract), 0))
def getImpedance(self, data, decoder_class):
"""计算阻抗ZMQ模式下不可用"""
return np.zeros(8)
def stop(self):
self.running = False
self.cmd_socket.close()
self.data_socket.close()
self.context.term()
if __name__ == '__main__':
server = zmqServer()
server.start()

View File

@@ -8,6 +8,7 @@ import os
# import logging # import logging
import base64 import base64
import io import io
import math
# logger = logging.getLogger(__name__) # logger = logging.getLogger(__name__)
# #
@@ -22,7 +23,7 @@ import io
class Calculate(): class Calculate():
def __init__(self, Threshold_value_low, Threshold_value_high, fs=250, win_len=10): def __init__(self, Threshold_value_low, Threshold_value_high, fs=250, win_len=10, config=None):
self.Threshold_value_low = Threshold_value_low self.Threshold_value_low = Threshold_value_low
self.Threshold_value_high = Threshold_value_high self.Threshold_value_high = Threshold_value_high
self.fs = fs self.fs = fs
@@ -30,48 +31,74 @@ class Calculate():
self.CLI_result = [] self.CLI_result = []
self.EVI_result = [] self.EVI_result = []
self.eegQueue = deque(maxlen=win_len) self.eegQueue = deque(maxlen=win_len)
# # 存储历史数据用于绘图
# self.beta_history = []
# self.alpha_history = []
# self.theta_history = []
# self.focus_history = []
# self.timestamp_history = []
#
# # 记录开始时间
# self.start_time = None
# self.recording = False
#
# # 图表保存路径
# self.chart_dir = "reports"
# if not os.path.exists(self.chart_dir):
# os.makedirs(self.chart_dir)
# print(f"[调试] 创建目录: {self.chart_dir}")
# 初始化滤波器 # 初始化滤波器
self.b_notch, self.a_notch = signal.iirnotch(50 / (self.fs/2), 30) self.b_notch, self.a_notch = signal.iirnotch(50 / (self.fs/2), 30)
self.b_design = signal.firwin(65, [2 / (self.fs/2), 40 / (self.fs/2)], pass_zero=False) self.b_design = signal.firwin(65, [2 / (self.fs/2), 40 / (self.fs/2)], pass_zero=False)
self.last_focus = None
# 异步滤波系数配置(核心手感控制纽)
self.alpha_up = 1 # 上升系数:较小,保证分数平滑爬升,过滤偶发的瞬时高能量
# alpha_down / shrink_factor 从 config.ini 读取,方便上位机调参
if config:
self.alpha_down = float(config.get('alpha_down', 0.8))
self.shrink_factor = float(config.get('shrink_factor', 0.5))
else:
self.alpha_down = 0.8
self.shrink_factor = 0.5
print("[调试] Calculate 类初始化完成") print("[调试] Calculate 类初始化完成")
def calculate_focus(self, beta, alpha, theta): def calculate_focus(self, beta, alpha, theta):
""" """
专注度计算 - 固定映射版本 专注度计算 - 三区间门限异步滤波版本
""" """
# 0. 频带特征预处理
theta_mod = theta ** 0.7
# 原始比值 # 原始比值
raw = beta / (alpha + theta + 1e-10) raw = beta / (alpha + theta_mod + 1e-10)
# Sigmoid 映射:让 raw 在 0.3-1.5 区间敏感 exponent = 2.0
# 参数可调:
# k = 12 (斜率,越大越陡) # 1. 防止脑电比值出现负数异常值
# x0 = 0.6 (中心点raw=0.6时focus≈50) raw_input = max(raw, 0.0)
k = 12.0
x0 = 0.6 # 2. 2次幂纵轴压缩映射 (shrink_factor 从 config.ini 读取)
focus = 100.0 / (1.0 + np.exp(-k * (raw - x0))) focus_raw = 100 * self.shrink_factor * (raw_input ** exponent)
# 可选:添加滑动平均平滑 # 3. 计算当前帧的瞬时分数 (基准量级 0-120)
instant_focus = 120 * (1.0 - np.exp(-focus_raw / 100.0))
# 4. 核心修改:三区间门限时域滤波
if self.last_focus is None:
# 冷启动:首帧直接赋值
focus = instant_focus
else:
# 判断当前瞬时分数是否处于【极端区】(80以上 或 60以下)
if instant_focus > 85.0 or instant_focus < 60.0:
# 执行异步低通时域滤波
if instant_focus >= self.last_focus:
# 趋势上升:慢爬升
focus = self.alpha_up * instant_focus + (1 - self.alpha_up) * self.last_focus
else:
# 趋势下降:快跌落
focus = self.alpha_down * instant_focus + (1 - self.alpha_down) * self.last_focus
else:
# 【高灵敏自由区】(60 <= instant_focus <= 80)
# 不执行异步滤波,分数直接跟随瞬时值,保证中间状态绝对跟手
focus = instant_focus
# 5. 更新历史状态缓存
self.last_focus = focus
# 打印在线调试日志,方便观察区间切换
zone_tag = "极端区(滤波)" if (instant_focus > 80 or instant_focus < 60) else "自由区(直通)"
print(f"原始特征比值 raw: {raw:.4f} | 瞬时分数: {instant_focus:.1f} | 滤波后分数: {focus:.1f}")
# 最终返回整型
return int(focus) return int(focus)
def calculate_all(self, data, fs, nperseg=1000): def calculate_all(self, data, fs, nperseg=1000):
mean_x = np.mean(data, axis=-1, keepdims=True) mean_x = np.mean(data, axis=-1, keepdims=True)
data = data - mean_x data = data - mean_x
@@ -90,7 +117,7 @@ class Calculate():
if len(self.focus_result) > 3: if len(self.focus_result) > 3:
self.focus_result.pop(0) self.focus_result.pop(0)
final_focus = int(self.simple_moving_average(self.focus_result, window_size=5)) final_focus = int(self.simple_moving_average(self.focus_result, window_size=5))
cli_denom = alpha_psd + beta_psd cli_denom = alpha_psd + beta_psd
CLI_score = np.log(theta_psd / (cli_denom + 1e-10)) if cli_denom > 0 else 0 CLI_score = np.log(theta_psd / (cli_denom + 1e-10)) if cli_denom > 0 else 0
self.CLI_result.append(CLI_score) self.CLI_result.append(CLI_score)
@@ -319,14 +346,16 @@ class Calculate():
if eegData.size == 0: if eegData.size == 0:
return None return None
eegData -= np.mean(eegData, axis=-1, keepdims=True) eegData -= np.mean(eegData, axis=-1, keepdims=True)
eegData = signal.lfilter(self.b_notch, self.a_notch, eegData) # eegData = signal.lfilter(self.b_notch, self.a_notch, eegData) # 陷波
eegData = signal.lfilter(self.b_design, 1, eegData) # eegData = signal.lfilter(self.b_design, 1, eegData) # 滤波
focus_score, CLI_score, beta, alpha, theta = self.calculate_all(eegData, fs=self.fs, nperseg=1000) focus_score, CLI_score, beta_psd, alpha_psd, theta_psd = self.calculate_all(eegData, fs=self.fs, nperseg=1000)
# self.add_data_point(focus_score, beta, alpha, theta) # self.add_data_point(focus_score, beta_psd, alpha_psd, theta_psd) # 已注释(方法已移除)
return focus_score # return (focus_score)
return None return (focus_score, beta_psd)
# return None
class Calculate2(): class Calculate2():

View File

@@ -19,10 +19,10 @@ Serial_port = COM44
algo_log_level = DEBUG algo_log_level = DEBUG
console_output = 1 console_output = 1
; 64 导设备配置
; 64 导设备配置 1; 32 2; 24 3; 16 4; 8 5; 4 6; [device_type_1]
[device_type] = 1 sample_rate = 250
device_sample_rate = 250 frame_points = 5
device_channel_nums = 66 channel_nums = 66
device_channel_names = ['FP1', 'FP2', 'PO6', 'POZ', 'F3', 'F4', 'FPZ', 'AF4', 'FC3', 'PO8', 'CP2', 'CP1', 'FCZ', 'PO5', 'FC2', 'FC1', 'C3', 'C4', 'FC4', 'CP4', 'P3', 'P4', 'F5', 'C5', 'F6', 'PO4', 'CP6', 'CP5', 'PO3', 'CP3', 'FC6', 'FC5', 'CB1', 'CB2', 'P5', 'AF7', 'A1', 'T7', 'FT7', 'TP7', 'FT8', 'AF8', 'F8', 'F7', 'P6', 'C6', 'O2', 'O1', 'T8', 'P7', 'CZ', 'PZ', 'P8', 'FZ', 'OZ', 'PO7', 'TP8', 'AF3', 'C2', 'C1', 'P2', 'P1', 'F2', 'F1', 'label', 'label_tag'] channel_names = ['FP1', 'FP2', 'PO6', 'POZ', 'F3', 'F4', 'FPZ', 'AF4', 'FC3', 'PO8', 'CP2', 'CP1', 'FCZ', 'PO5', 'FC2', 'FC1', 'C3', 'C4', 'FC4', 'CP4', 'P3', 'P4', 'F5', 'C5', 'F6', 'PO4', 'CP6', 'CP5', 'PO3', 'CP3', 'FC6', 'FC5', 'CB1', 'CB2', 'P5', 'AF7', 'A1', 'T7', 'FT7', 'TP7', 'FT8', 'AF8', 'F8', 'F7', 'P6', 'C6', 'O2', 'O1', 'T8', 'P7', 'CZ', 'PZ', 'P8', 'FZ', 'OZ', 'PO7', 'TP8', 'AF3', 'C2', 'C1', 'P2', 'P1', 'F2', 'F1', 'label', 'label_tag']
device_channel_index = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65] channel_index = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65]

138
datamock.py Normal file
View File

@@ -0,0 +1,138 @@
import zmq
import numpy as np
import time
from datetime import datetime
# ========== 参数配置 ==========
FS = 250 # 采样率 Hz
N_SAMPLES_PER_PKT = 5 # 每包采样点数
N_CHAN = 66 # 通道数: 64 EEG + 1 标签值 + 1 标签序号
EEG_FREQ = 10 # EEG 正弦波频率 Hz
EEG_AMP = 100.0 # EEG 幅值 100μV
LABEL_INTERVAL = 5 # 标签间隔秒数
SERVER_ADDR = 'tcp://127.0.0.1:8100'
# 发送间隔: 每包 5 采样点 / 250Hz = 20ms
PKT_INTERVAL = N_SAMPLES_PER_PKT / FS
def build_packet(global_sample_idx):
"""
生成一包 [5, 66] 的 float32 数据
:param global_sample_idx: 当前包第一个采样点在全局序列中的索引 (从 0 开始)
:return: np.ndarray shape [5, 66]
"""
# 当前包内 5 个采样点对应的时间(秒)
t = (global_sample_idx + np.arange(N_SAMPLES_PER_PKT)) / FS
# Ch0-63: EEG 10Hz 正弦波,幅值 100μV
# t shape [5,]sin 乘以标量后仍是 [5,],需要 reshape 为 [5,1] 再广播到 64 通道
eeg = (EEG_AMP * np.sin(2 * np.pi * EEG_FREQ * t)).reshape(N_SAMPLES_PER_PKT, 1) # [5, 1]
eeg = np.tile(eeg, (1, 64)) # [5, 64]
# Ch64: 标签值通道,初始化为 0
event = np.zeros((N_SAMPLES_PER_PKT, 1), dtype=np.float32)
# Ch65: 标签序号通道,初始化为 0
label_idx = np.zeros((N_SAMPLES_PER_PKT, 1), dtype=np.float32)
# 拼成 [5, 66]
packet = np.concatenate([eeg, event, label_idx], axis=1).astype(np.float32)
return packet
def should_send_label(global_sample_idx):
"""
判断当前包是否包含标签触发点(每 5s 的最后一个采样点)
采样点索引从 0 开始,每 5s = 1250 个采样点
最后一个采样点索引: 1249, 2499, 3749, ...
由于每包 5 个采样点,标签点落在包内的最后一个采样点位置
即当前包起始索引 global_sample_idx 必须使得:
global_sample_idx <= 标签点索引 < global_sample_idx + N_SAMPLES_PER_PKT
也就是 global_sample_idx <= 1249 < global_sample_idx + 5
即 global_sample_idx = 1245, 2495, 3745, ...
即 global_sample_idx = n * LABEL_INTERVAL * FS - N_SAMPLES_PER_PKT
"""
samples_per_interval = LABEL_INTERVAL * FS
# 检查当前包是否包含 interval 的最后一个采样点
# 标签点索引 = n * 1250 - 1当 global_sample_idx = n*1250-5 时,标签在包内索引 4
return (global_sample_idx + N_SAMPLES_PER_PKT - 1) % samples_per_interval == samples_per_interval - 1
def main():
ctx = zmq.Context()
sock = ctx.socket(zmq.DEALER)
sock.connect(SERVER_ADDR)
print(f"[{datetime.now().strftime('%H:%M:%S')}] ZMQ Dealer 连接到 {SERVER_ADDR}")
global_sample_idx = 0 # 全局采样点计数器
label_type = 1 # 当前标签类型: 1 或 2
label1_count = 0 # label=1 的序号计数器
label2_count = 0 # label=2 的序号计数器
packet_count = 0 # 已发送包数
print(f"[{datetime.now().strftime('%H:%M:%S')}] 开始发送模拟数据 ...")
print(f" 采样率: {FS}Hz | 每包 {N_SAMPLES_PER_PKT} 采样点 | 发送间隔 {PKT_INTERVAL*1000:.0f}ms")
print(f" EEG: {EEG_FREQ}Hz 正弦波 | 幅值 {EEG_AMP}μV")
print(f" 标签: 每 {LABEL_INTERVAL}s 末尾采样点触发 | label 1/2 交替")
print("-" * 50)
try:
while True:
t_start = time.perf_counter()
# 构建当前包
packet = build_packet(global_sample_idx)
# 检查是否需要放置标签
if should_send_label(global_sample_idx):
if label_type == 1:
label1_count += 1
label_value = 1
label_number = label1_count
else:
label2_count += 1
label_value = 2
label_number = label2_count
# 标签放在当前包最后一个采样点(索引 4
packet[4, 64] = label_value
packet[4, 65] = label_number
ts = datetime.now().strftime('%H:%M:%S')
print(f"[{ts}] 标签触发: label={label_value}, 序号={label_number} "
f"(global_sample_idx={global_sample_idx})")
# 交替标签类型
label_type = 2 if label_type == 1 else 1
# 发送: multipart 3帧 [identity, '', data]
sock.send_multipart([
b'datamock',
b'',
packet.tobytes()
])
# 每 50 包打印一次进度
if packet_count % 50 == 0:
ts = datetime.now().strftime('%H:%M:%S')
print(f"[{ts}] 已发送 {packet_count} 包 (global_sample_idx={global_sample_idx})")
global_sample_idx += N_SAMPLES_PER_PKT
packet_count += 1
# 精确控制发送节奏: 等待到 PKT_INTERVAL 秒
elapsed = time.perf_counter() - t_start
sleep_time = PKT_INTERVAL - elapsed
if sleep_time > 0:
time.sleep(sleep_time)
except KeyboardInterrupt:
print(f"\n[{datetime.now().strftime('%H:%M:%S')}] 停止发送,共发送 {packet_count}")
finally:
sock.close()
ctx.term()
if __name__ == '__main__':
main()

View File

@@ -1,79 +1,71 @@
# log.py
import os import os
from datetime import datetime from datetime import datetime
import logging import logging
from logging.handlers import RotatingFileHandler from logging.handlers import RotatingFileHandler
import inspect # 新增导入
from PubLibrary.InifileHelper import IniRead from PubLibrary.InifileHelper import IniRead
console_output = IniRead('system', 'console_output', '1') console_output = IniRead('system', 'console_output', '1')
log_level = IniRead('system', 'algo_log_level', 'INFO') log_level = IniRead('system', 'algo_log_level', 'INFO')
# 新增日志去重缓存key为日志内容value为是否已打印
log_once_cache = set() log_once_cache = set()
# 缓存已经创建过的logger避免重复创建handler
logger_cache = {}
def init_module_logger(): def init_module_logger(logger_name):
""" log_dir = './logs/'
初始化指定模块的日志器
:return: 对应模块的logger实例
"""
# 缓存命中则直接返回
log_dir = './logs/' # 确保日志目录存在
os.makedirs(log_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f'algo_log_{datetime.now().strftime("%Y-%m-%d")}.log') log_file = os.path.join(log_dir, f'algo_log_{datetime.now().strftime("%Y-%m-%d")}.log')
# 初始化logger # 已创建直接返回
logger = logging.getLogger('decoderLogger') if logger_name in logger_cache:
return logger_cache[logger_name]
logger = logging.getLogger(logger_name)
logger.setLevel(log_level) logger.setLevel(log_level)
if logger.handlers: if logger.handlers:
return logger logger_cache[logger_name] = logger
return logger
# 设置日志轮转最大10个文件每个10MB
file_handler = RotatingFileHandler( file_handler = RotatingFileHandler(
log_file, log_file,
maxBytes=10*1024*1024, maxBytes=10*1024*1024,
backupCount=10, backupCount=10,
encoding='utf-8' encoding='utf-8'
) )
# 日志格式
formatter = logging.Formatter( formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s', '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S' datefmt='%Y-%m-%d %H:%M:%S'
) )
file_handler.setFormatter(formatter) file_handler.setFormatter(formatter)
logger.setLevel(log_level)
logger.addHandler(file_handler) logger.addHandler(file_handler)
if console_output: if console_output:
console_handler = logging.StreamHandler() console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter) console_handler.setFormatter(formatter)
logger.addHandler(console_handler) logger.addHandler(console_handler)
logger_cache[logger_name] = logger
return logger return logger
def algo_log(content, level="INFO", record_once=False): def algo_log(content, level="INFO", record_once=False):
""" # 向上回溯1层栈拿到调用algo_log的代码文件信息
通用日志函数,支持按模块输出到不同日志文件 frame = inspect.currentframe().f_back
:param content: 日志内容 file_path = frame.f_code.co_filename
:param level: 日志级别DEBUG/INFO/WARNING/ERROR/FATAL # 提取py文件名不带后缀/带后缀自选)
:param record_once: 是否只打印一次该日志内容默认False file_name = os.path.basename(file_path) # 例zmqServer.py
""" # file_name = os.path.splitext(os.path.basename(file_path))[0] # 例zmqServer
# 初始化模块日志器
logger = init_module_logger() logger = init_module_logger(file_name)
# 新增:处理只打印一次的逻辑
if record_once: if record_once:
# 生成唯一标识可根据需要调整比如拼接level增强唯一性
log_key = f"{level.upper()}_{content}" log_key = f"{level.upper()}_{content}"
if log_key in log_once_cache: if log_key in log_once_cache:
return # 已打印过,直接返回 return
log_once_cache.add(log_key) # 未打印过,加入缓存 log_once_cache.add(log_key)
# 根据级别输出日志
level_upper = level.upper() level_upper = level.upper()
if level_upper == "DEBUG": if level_upper == "DEBUG":
logger.debug(content) logger.debug(content)
@@ -83,5 +75,5 @@ def algo_log(content, level="INFO", record_once=False):
logger.error(content) logger.error(content)
elif level_upper == "FATAL": elif level_upper == "FATAL":
logger.fatal(content) logger.fatal(content)
else: # 默认INFO级别 else:
logger.info(content) logger.info(content)

View File

@@ -6,32 +6,33 @@ import time
from Decoder import Decoder_main from Decoder import Decoder_main
from PubLibrary.RunOnce import is_program_running from PubLibrary.RunOnce import is_program_running
from PubLibrary.InifileHelper import IniRead from PubLibrary.InifileHelper import IniRead
from logs.log import algo_log
def get_device_info(device_type): def get_device_info(device_type):
section = f'device_type_{device_type}' section = f'device_type_{device_type}'
device_info = { device_info = {
'device_sample_rate': int(IniRead(section, 'sample_rate')) if IniRead(section, 'sample_rate') is not None else 250, 'sample_rate': int(IniRead(section, 'sample_rate')) if IniRead(section, 'sample_rate') is not None else 250,
'frame_points': int(IniRead(section, 'frame_points')) if IniRead(section, 'frame_points') is not None else 5,
'' 'channel_nums': int(IniRead(section, 'channel_nums')) if IniRead(section, 'channel_nums') is not None else 66,
'channel_names': IniRead(section, 'channel_names') if IniRead(section, 'channel_names') is not None else None,
'channel_index': IniRead(section, 'channel_index') if IniRead(section, 'channel_index') is not None else None,
} }
return device_info
if __name__ == "__main__": if __name__ == "__main__":
if not is_program_running(): if not is_program_running():
# 解析命令行参数 # 解析命令行参数
parser = argparse.ArgumentParser(description="EEG Decoder Application") # parser = argparse.ArgumentParser(description="EEG Decoder Application")
parser.add_argument('-dt', '-t','--device-type', type=int, default=None, help="Device Type") # parser.add_argument('-dt', '-t','--device-type', type=int, default=None, help="Device Type")
# parser.add_argument('-dh', '--device-host', type=str, default=None, help="Device Host IP") # parser.add_argument('-dh', '--device-host', type=str, default=None, help="Device Host IP")
# parser.add_argument('-dp', '--device-port', type=int, default=None, help="Device Port") # parser.add_argument('-dp', '--device-port', type=int, default=None, help="Device Port")
# parser.add_argument('-uh', '--upper-host', type=str, default=None, help="Upper Computer Host IP") # parser.add_argument('-uh', '--upper-host', type=str, default=None, help="Upper Computer Host IP")
# parser.add_argument('-up', '--upper-port', type=int, default=None, help="Upper Computer Port") # parser.add_argument('-up', '--upper-port', type=int, default=None, help="Upper Computer Port")
# args = parser.parse_args()
args = parser.parse_args()
device_info= get_device_info(args.device_type)
decoder = Decoder_main(device_info=device_info)
# decoder.connect( # decoder.connect(
# device_type=args.device_type, # device_type=args.device_type,
# device_host=args.device_host, # device_host=args.device_host,
@@ -40,6 +41,10 @@ if __name__ == "__main__":
# upper_port=args.upper_port # upper_port=args.upper_port
# ) # )
device_info= get_device_info(1)
algo_log(f"device_info: {device_info}", level="INFO")
decoder = Decoder_main(device_info=device_info)
try: try:
decoder.start() decoder.start()
while not decoder.zmqServer.IsExitApp: while not decoder.zmqServer.IsExitApp: