Compare commits

...

11 Commits

Author SHA1 Message Date
Ivey Song
30c690e4e3 Merge branch 'master' of http://47.98.56.110:7001/lizhao/bci_algo 2026-06-06 16:08:31 +08:00
Ivey Song
853037726d update 2026-06-06 16:05:32 +08:00
8a9d9a5c78 update log 2026-06-06 16:04:33 +08:00
29b6118f11 v1 2026-06-06 15:53:50 +08:00
fce7d93d5e delete server1 2026-06-06 15:15:10 +08:00
949801198e update 2026-06-06 15:13:23 +08:00
Ivey Song
494515463d 专注力计算 2026-06-06 14:57:52 +08:00
Ivey Song
9a655ffdeb 删除以往pth模型 2026-06-06 14:51:00 +08:00
Ivey Song
a9fd51e935 brainmap 2026-06-06 14:50:16 +08:00
Ivey Song
4b7e48be38 数据模拟 2026-06-06 14:49:38 +08:00
2d190d6431 add buffer 2026-06-06 14:40:07 +08:00
9 changed files with 692 additions and 968 deletions

View File

@@ -1,4 +1,7 @@
import ast import ast
import glob
import os
import sys
import threading import threading
from datetime import datetime from datetime import datetime
import multiprocessing as mp import multiprocessing as mp
@@ -8,7 +11,7 @@ import torch
from queue import Empty from queue import Empty
from scipy import signal from scipy import signal
from torch.autograd import Variable from torch.autograd import Variable
from Device.SunnyLinker import SunnyLinker64 # from Device.SunnyLinker import SunnyLinker64
from SSMVEP.algorithm.tdca import TDCA from SSMVEP.algorithm.tdca import TDCA
from SSMVEP.algorithm.base import generate_cca_references from SSMVEP.algorithm.base import generate_cca_references
from concentration.algorithm.calculate_focus import Calculate from concentration.algorithm.calculate_focus import Calculate
@@ -17,49 +20,70 @@ from Zmq.zmqServer import zmqServer
from Zmq.zmqClient import zmqClient from Zmq.zmqClient import zmqClient
from MI.Algorithm.conformer_2class import onlineTrain from MI.Algorithm.conformer_2class import onlineTrain
from PubLibrary.InifileHelper import IniRead from PubLibrary.InifileHelper import IniRead
from logs.log import algo_log
from SSVEP.dwfbcca import FbccaDw from SSVEP.dwfbcca import FbccaDw
from Tools.plot_MI_EEG import plotMain # from Tools.plot_MI_EEG import plotMain
from collections import deque from collections import deque
class Decoder_main(threading.Thread, device_type):
def __init__(self, device_type=None): def get_root_path():
"""
Nuitka 打包专用:获取程序根目录(.py 或 .exe 所在目录)
"""
if getattr(sys, 'frozen', False):
# 打包后:返回 exe 所在目录
return os.path.dirname(sys.executable)
else:
# 开发时:返回 py 文件所在目录
return os.path.dirname(os.path.abspath(__file__))
MODEL_FOLDER = "online_Models"
class Decoder_main(threading.Thread):
def __init__(self, device_info=None):
threading.Thread.__init__(self) threading.Thread.__init__(self)
self.device_info = {
'sample_rate': device_info['sample_rate'],
'frame_points': device_info['frame_points'],
'channel_nums': device_info['channel_nums'],
'channel_names': device_info['channel_names'],
'channel_index': device_info['channel_index'],
}
self.Runing=True self.Runing=True
self.decoder = None self.decoder = None
self.fs = 250 # 采样率
self.energy = 0 # 电量
self.status_code = 0 # 与采集设备通信的状态码0为异常1为正常
self.decoder_class = None #解码器类别 self.decoder_class = None #解码器类别
# 与采集设备通信的状态码0为异常1为正常
# self.status_code = 0
# self.device_info['sample_rate'] = 250 # 采样率
# self.energy = 0 # 电量
self.decodingSteps = 0 # 0=停止解码 1=预热 2=解码中 3=解码完成,发送解码结果 self.decodingSteps = 0 # 0=停止解码 1=预热 2=解码中 3=解码完成,发送解码结果
self.device_info = {
'device_type': None,
'sample_rate': None,
'channel_num': None,
}
def connect(self, device_type=None, device_host=None, device_port=None, upper_host=None, upper_port=None): self.zmqServer = zmqServer(device_info=self.device_info)
self.DeviceType = device_type if device_type is not None else int(IniRead('system', 'Device_type'))
_device_host = device_host if device_host is not None else str(IniRead('system', 'Device_Host'))
_device_port = device_port if device_port is not None else int(IniRead('system', 'Device_Port'))
_upper_host = upper_host if upper_host is not None else str(IniRead('system', 'Upper_Host'))
_upper_port = upper_port if upper_port is not None else int(IniRead('system', 'Upper_Port'))
if self.DeviceType == 1:
self.thread_data_server = SunnyLinker64(_device_host, _device_port, self.fs, 64, method='tcp')
self.thread_data_server.host = _device_host
self.thread_data_server.port = _device_port
self.thread_data_server.toUv = True
self.thread_data_server.start()
self.zmqServer = zmqServer()
self.zmqServer.start() self.zmqServer.start()
self.zmqClient = zmqClient(_upper_host, _upper_port) # self.zmqClient = zmqClient(_upper_host, _upper_port)
self.zmqClient.set_zmq_server(self.zmqServer) # self.zmqClient.set_zmq_server(self.zmqServer)
self.zmqClient.connect() # self.zmqClient.connect()
# def connect(self, device_type=None, device_host=None, device_port=None, upper_host=None, upper_port=None):
# self.DeviceType = device_type if device_type is not None else int(IniRead('system', 'Device_type'))
# _device_host = device_host if device_host is not None else str(IniRead('system', 'Device_Host'))
# _device_port = device_port if device_port is not None else int(IniRead('system', 'Device_Port'))
# _upper_host = upper_host if upper_host is not None else str(IniRead('system', 'Upper_Host'))
# _upper_port = upper_port if upper_port is not None else int(IniRead('system', 'Upper_Port'))
# if self.DeviceType == 1:
# self.thread_data_server = SunnyLinker64(_device_host, _device_port, self.device_info['sample_rate'], 64, method='tcp')
# self.thread_data_server.host = _device_host
# self.thread_data_server.port = _device_port
# self.thread_data_server.toUv = True
# self.thread_data_server.start()
def is_valid_signal(self, data, threshold=1e5): # 判断当前信号是否为有效信号 def is_valid_signal(self, data, threshold=1e5): # 判断当前信号是否为有效信号
# data: (chans, samples) # data: (chans, samples)
@@ -76,26 +100,25 @@ class Decoder_main(threading.Thread, device_type):
self.decoder_class = decoder_class self.decoder_class = decoder_class
if decoder_class == 'ssvep' or decoder_class == 'pvs': if decoder_class == 'ssvep' or decoder_class == 'pvs':
self.n_chan = 8 self.n_chan = 8
self.thread_data_server.interval_inited = False # self.thread_data_server.interval_inited = False
DW_cost_method, self.DW_cost_tv = ast.literal_eval(IniRead('system', 'SSVEP_ThresholdValue')) DW_cost_method, self.DW_cost_tv = ast.literal_eval(IniRead('system', 'SSVEP_ThresholdValue'))
self.ListFreq = self.zmqServer.targetFreqs self.ListFreq = self.zmqServer.targetFreqs
self.num_target = len(self.ListFreq) self.num_target = len(self.ListFreq)
if self.num_target == 0: if self.num_target == 0:
return return
# 初始化对象 二代算法 # 初始化对象 二代算法
self.dw = FbccaDw(self.fs, self.num_target, self.n_chan, 5, 5, self.dw = FbccaDw(self.device_info['sample_rate'], self.num_target, self.n_chan, 5, 5,
0.2, [2.0, 0.1], [8, 7], 50, DW_cost_method) 0.2, [2.0, 0.1], [8, 7], 50, DW_cost_method)
# frequence band # frequence band
self.dw.filterFrequenceBank() self.dw.filterFrequenceBank()
self.dw.setNotchFilterPara() self.dw.setNotchFilterPara()
self.calculateCount = 0 self.calculateCount = 0
self.referenceData = self.dw.reference(self.ListFreq, int(50 * 0.2 * self.fs), self.referenceData = self.dw.reference(self.ListFreq, int(50 * 0.2 * self.device_info['sample_rate']), 5)
5)
self.dw.filterInit() self.dw.filterInit()
self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时 self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时
elif decoder_class == 'ssmvep': elif decoder_class == 'ssmvep':
self.thread_data_server.interval_init(decoder_class) self.zmqServer.interval_init(decoder_class)
self.n_chan = 8 self.n_chan = 8
self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch')) self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位 self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位
@@ -104,12 +127,12 @@ class Decoder_main(threading.Thread, device_type):
self.list_freqs = np.array([8, 9]) # 刺激频率 self.list_freqs = np.array([8, 9]) # 刺激频率
self.list_phase = np.array([0, 0]) # 相位 self.list_phase = np.array([0, 0]) # 相位
self.tdca = TDCA(padding_len=5, n_components=1) self.tdca = TDCA(padding_len=5, n_components=1)
self.Yf = generate_cca_references(self.list_freqs, srate=self.fs, T=self.sample_length, self.Yf = generate_cca_references(self.list_freqs, srate=self.device_info['sample_rate'], T=self.sample_length,
phases=self.list_phase, n_harmonics=5) phases=self.list_phase, n_harmonics=5)
self.parameter_init(5,45) self.parameter_init(5,45)
elif decoder_class == 'mi' or decoder_class == 'ma': elif decoder_class == 'mi' or decoder_class == 'ma':
self.thread_data_server.interval_init(decoder_class) self.zmqServer.interval_init(decoder_class)
self.n_chan = 21 self.n_chan = 21
self.interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch')) self.interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位 self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位
@@ -124,7 +147,7 @@ class Decoder_main(threading.Thread, device_type):
# self.win_len = 10 # self.win_len = 10
# self.win_step = 1 # self.win_step = 1
# self.low_threshold, self.high_threshold = ast.literal_eval(IniRead('system', 'concentration_ThresholdValue')) # self.low_threshold, self.high_threshold = ast.literal_eval(IniRead('system', 'concentration_ThresholdValue'))
# self.calculate = Calculate(self.low_threshold, self.high_threshold, self.fs, self.win_len) # self.calculate = Calculate(self.low_threshold, self.high_threshold, self.device_info['sample_rate'], self.win_len)
# self.interval_epoch = [0, 1] # self.interval_epoch = [0, 1]
# self.parameter_init(2, 40) # self.parameter_init(2, 40)
# # self.eegQueue moved to Calculate class # # self.eegQueue moved to Calculate class
@@ -136,8 +159,8 @@ class Decoder_main(threading.Thread, device_type):
# self.total_samples = 0 # 总采样点数 # self.total_samples = 0 # 总采样点数
# self.window_ms = 600 # 检测窗口大小 (ms) # self.window_ms = 600 # 检测窗口大小 (ms)
# self.step_ms = 100 # 滑动步长 (ms) # self.step_ms = 100 # 滑动步长 (ms)
# self.window_samples = int(self.window_ms * self.fs / 1000) # 150个样本点 # self.window_samples = int(self.window_ms * self.device_info['sample_rate'] / 1000) # 150个样本点
# self.step_samples = int(self.step_ms * self.fs / 1000) # 25个样本点 # self.step_samples = int(self.step_ms * self.device_info['sample_rate'] / 1000) # 25个样本点
# self.buffer_size = self.window_samples + self.step_samples * 5 # self.buffer_size = self.window_samples + self.step_samples * 5
# self.fp1_buffer = deque(maxlen=self.buffer_size) # self.fp1_buffer = deque(maxlen=self.buffer_size)
# self.fp2_buffer = deque(maxlen=self.buffer_size) # self.fp2_buffer = deque(maxlen=self.buffer_size)
@@ -151,11 +174,11 @@ class Decoder_main(threading.Thread, device_type):
# self.double_blink_events = [] # 连续眨眼事件记录 # self.double_blink_events = [] # 连续眨眼事件记录
# self.last_double_blink_time = 0 # 上次检测到连续眨眼的时间戳 # self.last_double_blink_time = 0 # 上次检测到连续眨眼的时间戳
# self.blink_events = [] # self.blink_events = []
# self.blink_b, self.blink_a = signal.butter(4, [self.l_freq / (self.fs / 2), self.h_freq / (self.fs / 2)], btype='band') # self.blink_b, self.blink_a = signal.butter(4, [self.l_freq / (self.device_info['sample_rate'] / 2), self.h_freq / (self.device_info['sample_rate'] / 2)], btype='band')
def parameter_init(self,bandPass_low,bandPass_high): def parameter_init(self,bandPass_low,bandPass_high):
self.interval_epoch = [int(i * self.fs) for i in self.interval_epoch] # epoch截取信息 self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in self.interval_epoch] # epoch截取信息
self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * self.fs)] # 训练样本epoch self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])] # 训练样本epoch
self.trainData = [] #训练数据 self.trainData = [] #训练数据
self.trainLabel = [] #训练标签 self.trainLabel = [] #训练标签
self.plotData = [] #报告分析数据 self.plotData = [] #报告分析数据
@@ -163,13 +186,15 @@ class Decoder_main(threading.Thread, device_type):
self.currentLabel = -1 #刺激界面当前显示的训练标签 self.currentLabel = -1 #刺激界面当前显示的训练标签
self.train_started = False #是否开始训练模型 self.train_started = False #是否开始训练模型
self.load_model = False # 调用模型是否完成的标志 self.load_model = False # 调用模型是否完成的标志
self.b_notch, self.a_notch = signal.iirnotch(50 / (self.fs/2), 30) # 50Hz工频陷波250是采样率30是质量因子 self.b_notch, self.a_notch = signal.iirnotch(50 / (self.device_info['sample_rate']/2), 30) # 50Hz工频陷波250是采样率30是质量因子
self.b_design = signal.firwin(65, [bandPass_low / (self.fs/2), bandPass_high / (self.fs/2)], pass_zero=False) # 设计8-30Hz带通滤波器 self.b_design = signal.firwin(65, [bandPass_low / (self.device_info['sample_rate']/2), bandPass_high / (self.device_info['sample_rate']/2)], pass_zero=False) # 设计8-30Hz带通滤波器
fileName = 'Model_' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S') fileName = 'Model_' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
filePath = './online_Models/' filePath = os.path.join(get_root_path(), MODEL_FOLDER) + os.sep
for old_pth in glob.glob(os.path.join(filePath, '*.pth')):
os.remove(old_pth)
self.modelPath = ''.join([filePath, fileName, '.pth']) self.modelPath = ''.join([filePath, fileName, '.pth'])
self.mp_data_queue = mp.Queue() #多进程传参队列 self.mp_data_queue = mp.Queue()
self.mp_result_queue = mp.Queue() #多进程结果队列 self.mp_result_queue = mp.Queue()
def preprocess(self, signal_data): def preprocess(self, signal_data):
# # 计算每行的平均值 # # 计算每行的平均值
@@ -192,54 +217,55 @@ class Decoder_main(threading.Thread, device_type):
# 同步信息 # 同步信息
if self.zmqServer.state_mode == 'sync': if self.zmqServer.state_mode == 'sync':
self.zmqClient.send_to_all('sync', self.zmqClient.state) # self.zmqClient.send_to_all('sync', self.zmqClient.state)
self.zmqServer.state_mode = 'rest' self.zmqServer.state_mode = 'rest'
# 状态异常,报告上位机
if self.status_code != self.thread_data_server.status_code:
self.status_code = self.thread_data_server.status_code
self.zmqClient.send_to_all('status_code', int(self.status_code))
print('status code')
# 返回电量 # # 状态异常,报告上位机
if self.energy != self.thread_data_server.energy: # if self.status_code != self.thread_data_server.status_code:
self.energy = self.thread_data_server.energy # self.status_code = self.thread_data_server.status_code
self.zmqClient.send_to_all('energy', int(self.energy)) # self.zmqClient.send_to_all('status_code', int(self.status_code))
print('energy') # print('status code')
if self.zmqServer.open_Impedance == True: # 开启阻抗检测功能,仅运行一次 # # 返回电量
self.thread_data_server.Impedance(True) # if self.energy != self.thread_data_server.energy:
print('Impedance') # self.energy = self.thread_data_server.energy
self.zmqServer.open_Impedance = -1 # self.zmqClient.send_to_all('energy', int(self.energy))
elif self.zmqServer.open_Impedance == False: # print('energy')
self.thread_data_server.Impedance(False)
self.zmqServer.open_Impedance = -1
if self.zmqServer.get_Impedance: # 返回阻抗值 # if self.zmqServer.open_Impedance == True: # 开启阻抗检测功能,仅运行一次
# print(self.zmqServer.get_Impedance) # self.thread_data_server.Impedance(True)
# print(self.thread_data_server.GetDataLenCount()) # print('Impedance')
if self.thread_data_server.GetDataLenCount() > 250: # self.zmqServer.open_Impedance = -1
Impe_data = self.thread_data_server.getData(250) # elif self.zmqServer.open_Impedance == False:
# 计算阻抗 # self.thread_data_server.Impedance(False)
imps = self.thread_data_server.getImpedance(Impe_data,self.zmqServer.decoder_class) # self.zmqServer.open_Impedance = -1
self.zmqClient.send_to_all('impedance', imps.tolist())
else: # if self.zmqServer.get_Impedance: # 返回阻抗值
pass # # print(self.zmqServer.get_Impedance)
if self.zmqServer.getReport: #返回训练报告内容 # # print(self.thread_data_server.GetDataLenCount())
self.zmqServer.getReport = False # if self.thread_data_server.GetDataLenCount() > 250:
allData = np.array(self.plotData) # Impe_data = self.thread_data_server.getData(250)
allLabel = np.array(self.plotLabel) + 1 # # 计算阻抗
nTrials = min(len(allLabel),len(allData)) # imps = self.thread_data_server.getImpedance(Impe_data,self.zmqServer.decoder_class)
if nTrials < 30: # self.zmqClient.send_to_all('impedance', imps.tolist())
self.zmqClient.send_to_all('miReport',0) # else:
else: # pass
allData = allData[:nTrials] # if self.zmqServer.getReport: #返回训练报告内容
allLabel = allLabel[:nTrials] # self.zmqServer.getReport = False
ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1', # allData = np.array(self.plotData)
'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4'] # allLabel = np.array(self.plotLabel) + 1
compare_names = ['C3', 'CZ', 'C4'] # nTrials = min(len(allLabel),len(allData))
miReport = plotMain(ch_names=ch_names,compare_names=compare_names,Data=allData,labels=allLabel,MI_label=1,Rest_label=2, # if nTrials < 30:
fs=self.fs) # self.zmqClient.send_to_all('miReport',0)
self.zmqClient.send_to_all('miReport',miReport) # else:
# allData = allData[:nTrials]
# allLabel = allLabel[:nTrials]
# ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1',
# 'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4']
# compare_names = ['C3', 'CZ', 'C4']
# miReport = plotMain(ch_names=ch_names,compare_names=compare_names,Data=allData,labels=allLabel,MI_label=1,Rest_label=2,
# fs=self.device_info['sample_rate'])
# self.zmqClient.send_to_all('miReport',miReport)
# --- 取数优先:先执行 decoder消费环形缓冲再处理 plot/report 等重负载 --- # --- 取数优先:先执行 decoder消费环形缓冲再处理 plot/report 等重负载 ---
@@ -250,34 +276,33 @@ class Decoder_main(threading.Thread, device_type):
self.decoder_SSMVEP() self.decoder_SSMVEP()
elif self.decoder_class == 'mi': elif self.decoder_class == 'mi':
self.decoder_MI() self.decoder_MI()
elif self.decoder_class == 'concentration': # elif self.decoder_class == 'concentration':
self.decoder_concentration() # self.decoder_concentration()
elif self.decoder_class == 'blink': # elif self.decoder_class == 'blink':
self.decoder_blink() # self.decoder_blink()
else: # else:
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态 # self.
if self.thread_data_server.GetDataLenCount() < 25: # # if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
time.sleep(0.005) # # if self.thread_data_server.GetDataLenCount() < 25:
continue; # # time.sleep(0.005)
self.thread_data_server.getData(25) # # continue;
# # self.thread_data_server.getData(25)
except Exception as e: except Exception as e:
print(f"Decoder Loop Error: {e}") algo_log(f"Decoder Loop Error: {e}")
import traceback
traceback.print_exc()
time.sleep(0.1) # Prevent CPU spin if error is persistent time.sleep(0.1) # Prevent CPU spin if error is persistent
def decoder_SSVEP(self): def decoder_SSVEP(self):
if self.zmqServer.StartDecode: if self.zmqServer.StartDecode:
self.zmqServer.StartDecode = False self.zmqServer.StartDecode = False
self.decodingSteps = 1 self.decodingSteps = 1
self.thread_data_server.ResetAll() self.zmqServer.paradigmBuffer.ResetAllPara()
print('启动预测') print('启动预测')
if self.thread_data_server.GetDataLenCount() < 50: if self.zmqServer.paradigmBuffer.GetDataLenCount() < 50:
time.sleep(0.005) time.sleep(0.005)
return return
if self.zmqServer.get_Impedance != False: # 阻抗检测状态不解码 if self.zmqServer.get_Impedance != False: # 阻抗检测状态不解码
return return
data = self.thread_data_server.getDataViaSSVEP(50) data = self.zmqServer.paradigmBuffer.getDataViaSSVEP(50)
data = data[:self.n_chan, :] data = data[:self.n_chan, :]
if self.decodingSteps == 1 and hasattr(self,'dw'): # 开始预热 if self.decodingSteps == 1 and hasattr(self,'dw'): # 开始预热
self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时 self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时
@@ -293,7 +318,7 @@ class Decoder_main(threading.Thread, device_type):
print('预测结果:' + str(choosenNum) + ',计算次数:' + str(self.calculateCount)) print('预测结果:' + str(choosenNum) + ',计算次数:' + str(self.calculateCount))
self.calculateCount = 0 self.calculateCount = 0
if self.decodingSteps == 3: # 发送解码后的信息 if self.decodingSteps == 3: # 发送解码后的信息
self.zmqClient.send_to_all('result', int(choosenNum)) self.zmqServer.broadcast_message('result', int(choosenNum))
self.decodingSteps = 0 self.decodingSteps = 0
print('发送给界面完成。') print('发送给界面完成。')
@@ -312,19 +337,19 @@ class Decoder_main(threading.Thread, device_type):
formatted_time = now.strftime('%H:%M:%S.%f')[:-3] formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
print('模型训练完成', formatted_time) print('模型训练完成', formatted_time)
self.load_model = True self.load_model = True
self.zmqClient.send_to_all('paradigm', 1) self.zmqServer.broadcast_message('paradigm', 1)
'''训练阶段采集数据''' '''训练阶段采集数据'''
if self.zmqServer.state_mode == 'train': # 训练状态 if self.zmqServer.state_mode == 'train': # 训练状态
if self.zmqServer.StartTrain: if self.zmqServer.StartTrain:
self.currentLabel = self.zmqServer.currentLabel self.currentLabel = self.zmqServer.currentLabel
self.zmqServer.StartTrain = False self.zmqServer.StartTrain = False
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \ if self.thread_data_server.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
self.train_epoch[1] \ self.train_epoch[1] \
+ self.thread_data_server.event_inner_idx: + self.thread_data_server.event_inner_idx:
time.sleep(0.0001) time.sleep(0.0001)
return return
print('训练队列数据:', self.thread_data_server.GetDataLenCount()) print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount())
trainTrial = self.thread_data_server.get_SSMVEPData() # 取出所有数据 trainTrial = self.thread_data_server.get_SSMVEPData() # 取出所有数据
print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, self.thread_data_server.event_inner_idx]) print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, self.thread_data_server.event_inner_idx])
trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理 trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理
@@ -348,7 +373,7 @@ class Decoder_main(threading.Thread, device_type):
formatted_time = now.strftime('%H:%M:%S.%f')[:-3] formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
print('启动预测 ', formatted_time) print('启动预测 ', formatted_time)
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \ if self.thread_data_server.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
self.interval_epoch[1] \ self.interval_epoch[1] \
+ self.thread_data_server.event_inner_idx: + self.thread_data_server.event_inner_idx:
time.sleep(0.0001) time.sleep(0.0001)
@@ -360,8 +385,8 @@ class Decoder_main(threading.Thread, device_type):
self.thread_data_server.event_inner_idx + self.interval_epoch[ self.thread_data_server.event_inner_idx + self.interval_epoch[
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]] 0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]]
pad_eeg_test = np.zeros( pad_eeg_test = np.zeros(
(data.shape[0], int((self.sample_length + 0.1) * self.fs))) (data.shape[0], int((self.sample_length + 0.1) * self.device_info['sample_rate'])))
pad_eeg_test[:, :int(self.sample_length * self.fs)] = data pad_eeg_test[:, :int(self.sample_length * self.device_info['sample_rate'])] = data
choosenNum, features_2 = self.decoder.predict(pad_eeg_test) choosenNum, features_2 = self.decoder.predict(pad_eeg_test)
if isinstance(choosenNum, np.ndarray): if isinstance(choosenNum, np.ndarray):
choosenNum = choosenNum[0] choosenNum = choosenNum[0]
@@ -371,7 +396,7 @@ class Decoder_main(threading.Thread, device_type):
print('发送给界面完成。') print('发送给界面完成。')
else: # 休息状态 else: # 休息状态
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态 if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
if self.thread_data_server.GetDataLenCount() < 25: if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
time.sleep(0.005) time.sleep(0.005)
return return
self.thread_data_server.getData(25) self.thread_data_server.getData(25)
@@ -419,12 +444,12 @@ class Decoder_main(threading.Thread, device_type):
if self.zmqServer.StartTrain: if self.zmqServer.StartTrain:
self.currentLabel = self.zmqServer.currentLabel self.currentLabel = self.zmqServer.currentLabel
self.zmqServer.StartTrain = False self.zmqServer.StartTrain = False
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \ if self.thread_data_server.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
self.interval_epoch[1] \ self.interval_epoch[1] \
+ self.thread_data_server.event_inner_idx: + self.thread_data_server.event_inner_idx:
time.sleep(0.0001) time.sleep(0.0001)
return return
print('训练队列数据:', self.thread_data_server.GetDataLenCount()) print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount())
originalTrial = self.thread_data_server.get_MIData() # 取出MI导联数据 originalTrial = self.thread_data_server.get_MIData() # 取出MI导联数据
print('取出的: ', originalTrial.shape, 'event: ', originalTrial[-2, self.thread_data_server.event_inner_idx]) print('取出的: ', originalTrial.shape, 'event: ', originalTrial[-2, self.thread_data_server.event_inner_idx])
trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理 trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理
@@ -448,7 +473,7 @@ class Decoder_main(threading.Thread, device_type):
formatted_time = now.strftime('%H:%M:%S.%f')[:-3] formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
print('启动预测 ', formatted_time) print('启动预测 ', formatted_time)
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \ if self.thread_data_server.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
self.interval_epoch[1] \ self.interval_epoch[1] \
+ self.thread_data_server.event_inner_idx: + self.thread_data_server.event_inner_idx:
time.sleep(0.0001) time.sleep(0.0001)
@@ -477,128 +502,128 @@ class Decoder_main(threading.Thread, device_type):
print(f'发送给界面完成,耗时{end - start:.3f}s。') print(f'发送给界面完成,耗时{end - start:.3f}s。')
else: # 休息状态 else: # 休息状态
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态 if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
if self.thread_data_server.GetDataLenCount() < 25: if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
time.sleep(0.005) time.sleep(0.005)
return return
self.thread_data_server.getData(25) self.thread_data_server.getData(25)
def decoder_concentration(self): # def decoder_concentration(self):
if self.zmqServer.state_mode == 'predict': # if self.zmqServer.state_mode == 'predict':
if self.zmqServer.StartDecode: # if self.zmqServer.StartDecode:
self.zmqServer.StartDecode = False # self.zmqServer.StartDecode = False
self.thread_data_server.ResetAll() # self.thread_data_server.ResetAll()
now = datetime.now() # now = datetime.now()
formatted_time = now.strftime('%H:%M:%S.%f')[:-3] # formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
print('启动专注力预测 ', formatted_time) # print('启动专注力预测 ', formatted_time)
if self.thread_data_server.GetDataLenCount() < int(self.win_step * self.fs): # 每win_step得出一次结果 # if self.thread_data_server.GetDataLenCount() < int(self.win_step * self.device_info['sample_rate']): # 每win_step得出一次结果
time.sleep(0.005) # time.sleep(0.005)
return # return
if self.zmqServer.get_Impedance != False: # 阻抗检测状态不解码 # if self.zmqServer.get_Impedance != False: # 阻抗检测状态不解码
return # return
data = self.thread_data_server.get_concentrateData(int(self.win_step * self.fs)) # 修改每次读取的数据 # data = self.thread_data_server.get_concentrateData(int(self.win_step * self.device_info['sample_rate'])) # 修改每次读取的数据
result = self.calculate.queueOpt(data) # result = self.calculate.queueOpt(data)
if result is not None: # if result is not None:
self.zmqClient.send_to_all('result', int(result)) # self.zmqClient.send_to_all('result', int(result))
else: # 休息状态 # else: # 休息状态
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态 # if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
if self.thread_data_server.GetDataLenCount() < 25: # if self.thread_data_server.GetDataLenCount() < 25:
time.sleep(0.005) # time.sleep(0.005)
return # return
self.thread_data_server.getData(25) # self.thread_data_server.getData(25)
#### Blink detection ##### #### Blink detection #####
def check_double_blink(self, current_time): # def check_double_blink(self, current_time):
""" # """
检查是否检测到连续两次眨眼 # 检查是否检测到连续两次眨眼
@param current_time: 当前眨眼时间戳 # @param current_time: 当前眨眼时间戳
@return: True表示检测到连续两次眨眼 # @return: True表示检测到连续两次眨眼
""" # """
if len(self.blink_timestamps) < 2: # if len(self.blink_timestamps) < 2:
return False # return False
# 检查是否在去抖期内 # # 检查是否在去抖期内
if self.last_double_blink_time > 0: # if self.last_double_blink_time > 0:
time_since_last_double_blink = current_time - self.last_double_blink_time # time_since_last_double_blink = current_time - self.last_double_blink_time
if time_since_last_double_blink < self.double_blink_jitter: # if time_since_last_double_blink < self.double_blink_jitter:
return False # 在去抖期内,忽略连续眨眼检测 # return False # 在去抖期内,忽略连续眨眼检测
last_time = self.blink_timestamps[-1] # 当前眨眼 # last_time = self.blink_timestamps[-1] # 当前眨眼
prev_time = self.blink_timestamps[-2] # 上次眨眼 # prev_time = self.blink_timestamps[-2] # 上次眨眼
interval = last_time - prev_time # interval = last_time - prev_time
if interval <= self.double_blink_interval: # if interval <= self.double_blink_interval:
return True # return True
return False # return False
def process_blink_detection(self): # def process_blink_detection(self):
""" # """
在缓冲区数据上执行,单次眨眼检测 # 在缓冲区数据上执行,单次眨眼检测
""" # """
if len(self.fp1_buffer) < self.window_samples: # if len(self.fp1_buffer) < self.window_samples:
return # return
fp1_data = np.array(list(self.fp1_buffer)[-self.window_samples:]) # fp1_data = np.array(list(self.fp1_buffer)[-self.window_samples:])
fp2_data = np.array(list(self.fp2_buffer)[-self.window_samples:]) # fp2_data = np.array(list(self.fp2_buffer)[-self.window_samples:])
# 计算FP1和FP2的平均 # # 计算FP1和FP2的平均
fp12_mean = (fp1_data + fp2_data) / 2.0 # fp12_mean = (fp1_data + fp2_data) / 2.0
# 带通滤波 # # 带通滤波
try: # try:
fp12_filtered = signal.filtfilt(self.blink_b, self.blink_a, fp12_mean) # fp12_filtered = signal.filtfilt(self.blink_b, self.blink_a, fp12_mean)
except Exception as e: # except Exception as e:
print(f"Filter error: {e}") # print(f"Filter error: {e}")
return # return
F = np.diff(fp12_filtered) # F = np.diff(fp12_filtered)
if len(F) < 3: # if len(F) < 3:
return # return
b, d, e = blink_detection(F, self.fs, self.Dmin, self.Dmax, self.EMin, self.EMax) # b, d, e = blink_detection(F, self.device_info['sample_rate'], self.Dmin, self.Dmax, self.EMin, self.EMax)
if b == 1: # if b == 1:
samples_since_last = self.total_samples - self.last_blink_time # samples_since_last = self.total_samples - self.last_blink_time
time_since_last_ms = (samples_since_last / self.fs) * 1000 # time_since_last_ms = (samples_since_last / self.device_info['sample_rate']) * 1000
if time_since_last_ms >= self.jitterwin: # self.jitterwin 单次眨眼去抖 using time_since_last_ms # if time_since_last_ms >= self.jitterwin: # self.jitterwin 单次眨眼去抖 using time_since_last_ms
self.blink_count += 1 # self.blink_count += 1
self.last_blink_time = self.total_samples # self.last_blink_time = self.total_samples
current_time = time.time() # current_time = time.time()
self.blink_timestamps.append(current_time) # self.blink_timestamps.append(current_time)
blink_event = { # blink_event = {
'count': self.blink_count, # 'count': self.blink_count,
'time': current_time, # 'time': current_time,
'sample_index': self.total_samples, # 'sample_index': self.total_samples,
'duration_ms': d, # 'duration_ms': d,
'energy': e # 'energy': e
} # }
self.blink_events.append(blink_event) # self.blink_events.append(blink_event)
self.zmqClient.send_to_all('result', 1) # 检测到眨眼信号,通知上位机 # self.zmqClient.send_to_all('result', 1) # 检测到眨眼信号,通知上位机
if self.check_double_blink(current_time): # if self.check_double_blink(current_time):
self.double_blink_count += 1 # self.double_blink_count += 1
interval = self.blink_timestamps[-1] - self.blink_timestamps[-2] # interval = self.blink_timestamps[-1] - self.blink_timestamps[-2]
double_blink_event = { # double_blink_event = {
'double_blink_count': self.double_blink_count, # 'double_blink_count': self.double_blink_count,
'blink1_time': self.blink_timestamps[-2], # 'blink1_time': self.blink_timestamps[-2],
'blink2_time': self.blink_timestamps[-1], # 'blink2_time': self.blink_timestamps[-1],
'interval': interval # 'interval': interval
} # }
self.double_blink_events.append(double_blink_event) # self.double_blink_events.append(double_blink_event)
self.last_double_blink_time = current_time # self.last_double_blink_time = current_time
self.zmqClient.send_to_all('result', 2) # 发送双次眨眼事件 # self.zmqClient.send_to_all('result', 2) # 发送双次眨眼事件
def decoder_blink(self): # def decoder_blink(self):
if self.thread_data_server.GetDataLenCount() < 50: # if self.thread_data_server.GetDataLenCount() < 50:
time.sleep(0.005) # time.sleep(0.005)
return # return
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态 # if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
data = self.thread_data_server.get_blinkData(50) # data = self.thread_data_server.get_blinkData(50)
fp1_data = data[0, :] # ch1 (相当于FP1) # fp1_data = data[0, :] # ch1 (相当于FP1)
fp2_data = data[1, :] # ch2 (相当于FP2) # fp2_data = data[1, :] # ch2 (相当于FP2)
for i in range(len(fp1_data)): # for i in range(len(fp1_data)):
self.fp1_buffer.append(fp1_data[i]) # self.fp1_buffer.append(fp1_data[i])
self.fp2_buffer.append(fp2_data[i]) # self.fp2_buffer.append(fp2_data[i])
self.total_samples += 1 # self.total_samples += 1
self.sample_counter += 1 # self.sample_counter += 1
if self.sample_counter >= self.step_samples: # if self.sample_counter >= self.step_samples:
self.process_blink_detection() # self.process_blink_detection()
self.sample_counter = 0 # self.sample_counter = 0
def stop(self): def stop(self):
''' '''

View File

@@ -65,6 +65,51 @@ class ParadigmRingBuffer:
''' '''
return self.nUpdate return self.nUpdate
# ========== 各范式数据访问接口 ==========
def get_MIData(self):
"""获取MI导联数据 (21通道 + 事件)"""
data = self.getData(self.GetDataLenCount())
rows_to_extract = [8, 15, 12, 14, 18, 23, 16, 59, 50, 58, 17, 45, 29, 11, 10, 19, 20, 61, 51, 60, 21, 64, 65]
row_to_select = np.array(rows_to_extract)
if data.shape[1] > 0:
return data[row_to_select, :]
return np.zeros((len(rows_to_extract), 0))
def get_SSMVEPData(self):
"""获取SSMVEP导联数据 (8通道 + 事件)"""
data = self.getData(self.GetDataLenCount())
rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64, 65]
row_to_select = np.array(rows_to_extract)
if data.shape[1] > 0:
return data[row_to_select, :]
return np.zeros((len(rows_to_extract), 0))
def getDataViaSSVEP(self, count):
"""获取SSVEP数据 (8通道 + 事件)"""
data = self.getData(count)
rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64]
row_to_select = np.array(rows_to_extract)
if data.shape[1] > 0:
return data[row_to_select, :]
return np.zeros((len(rows_to_extract), 0))
def get_concentrateData(self, count):
"""获取专注力数据 (2通道)"""
data = self.getData(count)
rows_to_extract = [0, 1]
row_to_select = np.array(rows_to_extract)
if data.shape[1] > 0:
return data[row_to_select, :]
return np.zeros((len(rows_to_extract), 0))
def get_blinkData(self, count):
"""获取眨眼数据 (2通道)"""
data = self.getData(count)
rows_to_extract = [0, 1]
row_to_select = np.array(rows_to_extract)
if data.shape[1] > 0:
return data[row_to_select, :]
return np.zeros((len(rows_to_extract), 0))
# reset buffer # reset buffer
def resetAllPara(self): def resetAllPara(self):
@@ -73,5 +118,3 @@ class ParadigmRingBuffer:
self.readPtr = 0 # add by lizhenhua 清空读指针 self.readPtr = 0 # add by lizhenhua 清空读指针
self.buffer = np.zeros((self.n_chan, self.n_points)) # add by lizhenhua 清空环形缓冲区 self.buffer = np.zeros((self.n_chan, self.n_points)) # add by lizhenhua 清空环形缓冲区

View File

@@ -1,16 +1,22 @@
import ast
import numpy as np import numpy as np
import zmq
import threading import threading
import json import json
import queue import queue
from typing import Dict
# from Device.SunnyLinker import SunnyLinker64 # from Device.SunnyLinker import SunnyLinker64
from dataBuffer import ParadigmRingBuffer from Zmq.dataBuffer import ParadigmRingBuffer
from filterProcess import FilterRingBuffer from Zmq.filterProcess import FilterRingBuffer
from PubLibrary.InifileHelper import IniRead
from logs.log import algo_log from logs.log import algo_log
import zmq
class zmqServer(threading.Thread): class zmqServer(threading.Thread):
def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100, device_info=None): def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100, device_info=None):
threading.Thread.__init__(self) threading.Thread.__init__(self)
self.device_info = device_info
self.host = host self.host = host
self.cmd_port = cmd_port # 命令交互端口 self.cmd_port = cmd_port # 命令交互端口
self.data_port = data_port # 数据接收端口 self.data_port = data_port # 数据接收端口
@@ -28,23 +34,22 @@ class zmqServer(threading.Thread):
self.daemon = True self.daemon = True
# 范式数据缓存 # 范式数据缓存
self.paradigmBuffer = ParadigmRingBuffer(66, 2500) self.paradigmBuffer = ParadigmRingBuffer(self.device_info['channel_nums'], self.device_info['sample_rate'] * 10)
self.filterBuffer = FilterRingBuffer(66, 2500) self.filterBuffer = FilterRingBuffer(self.device_info['channel_nums'], self.device_info['sample_rate'] * 10)
# 命令与数据通信 # 命令与数据通信
self.context = zmq.Context() self.context = zmq.Context()
# 指令通道 (8099) - ROUTER短JSON命令低频率 # 指令通道 (8099) - ROUTER短JSON命令低频率
self.cmd_socket = self.context.socket(zmq.ROUTER) self.cmd_socket = self.context.socket(zmq.ROUTER)
self.cmd_socket.setsockopt(zmq.RCVHWM, 100) # 指令不需要大缓存100条足够 # 通用套接字选项:仍在 SocketOption 中
self.cmd_socket.setsockopt(zmq.SNDHWM, 100) self.cmd_socket.setsockopt(zmq.SocketOption.RCVHWM, 100)
self.cmd_socket.setsockopt(zmq.TCP_NODELAY, 1) # 禁用Nagle算法降低指令延迟 self.cmd_socket.setsockopt(zmq.SocketOption.SNDHWM, 100)
self.cmd_socket.bind(f"tcp://{self.host}:{cmd_port}") self.cmd_socket.bind(f"tcp://{self.host}:{cmd_port}")
# 数据通道 (8100) - ROUTER高频脑电二进制流 # 数据通道 (8100) - ROUTER高频脑电二进制流
self.data_socket = self.context.socket(zmq.ROUTER) self.data_socket = self.context.socket(zmq.ROUTER)
self.data_socket.setsockopt(zmq.RCVHWM, 500) # 500包=10秒缓存足够应对短时卡顿 self.data_socket.setsockopt(zmq.SocketOption.RCVHWM, 500)
self.data_socket.setsockopt(zmq.TCP_NODELAY, 1) # 禁用Nagle算法减少数据传输延迟
self.data_socket.bind(f"tcp://{self.host}:{data_port}") self.data_socket.bind(f"tcp://{self.host}:{data_port}")
# Poller 轮训器(保持不变) # Poller 轮训器(保持不变)
@@ -65,6 +70,77 @@ class zmqServer(threading.Thread):
self.data_clients = set() # 数据端口客户端ID self.data_clients = set() # 数据端口客户端ID
self.send_queue = queue.Queue() # 发送队列(仅用于命令端口广播) self.send_queue = queue.Queue() # 发送队列(仅用于命令端口广播)
# 范式buffer参数, 事件检测相关
self._event_lock = threading.Lock()
self._epoch_finished = False
self._event_inner_idx = -1
self.pack_contain_event = False
self.predict_event = 99
self.events = [1, 2, self.predict_event]
self.count_events = {}
self.latency = 50
self.train_latency = 50
self._interval_inited = False
@property
def interval_inited(self):
return self._interval_inited
@interval_inited.setter
def interval_inited(self, value):
self._interval_inited = value
@property
def epoch_finished(self):
with self._event_lock:
return self._epoch_finished
@epoch_finished.setter
def epoch_finished(self, value):
with self._event_lock:
self._epoch_finished = value
@property
def event_inner_idx(self):
with self._event_lock:
return self._event_inner_idx
@event_inner_idx.setter
def event_inner_idx(self, value):
with self._event_lock:
self._event_inner_idx = value
def interval_init(self, decoder_class):
if decoder_class == 'ssmvep':
interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] # epoch截取信息
self.train_epoch = [int(self.interval_epoch[0]),
int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])] # 训练样本epoch
self.latency = (self.interval_epoch[
1] + 0.1 * self.device_info['sample_rate']) // 5 # 提取epoch的延迟标记5代表每次解包得到的5位采样点;0.1表示比实际需要的长度多取0.1,会被截掉
self.train_latency = (self.train_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5
elif decoder_class == 'mi':
interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] # epoch截取信息
self.train_epoch = self.interval_epoch.copy()
self.latency = (self.interval_epoch[1]) // 5 # 提取epoch的延迟标记5代表每次解包得到的5位采样点;
self.train_latency = self.latency
print('时间窗:', (interval_epoch))
self.count_events: Dict[str, int] = {} # 表示包延迟的计数信息
self.event_inner_idx = -1 # event在5位数据包内部的idx
self.epoch_finished = False # 接收epoch是否完整
self.pack_contain_event = False # 当前包是否含有event
self.predict_event = 99
self.events = [1, 2, self.predict_event]
self.interval_inited = True
# if getattr(self, 'serial', None) and self.serial.is_open:
# self.serial.close()
# self.serial = serial.Serial(self.serial_port, 460800, timeout=1) # 连接同步器串口
def broadcast_message(self, method, params): def broadcast_message(self, method, params):
"""Put message into queue to be sent to all command clients""" """Put message into queue to be sent to all command clients"""
self.send_queue.put((method, params)) self.send_queue.put((method, params))
@@ -78,15 +154,15 @@ class zmqServer(threading.Thread):
# 注册新的命令客户端 # 注册新的命令客户端
if ident not in self.cmd_clients: if ident not in self.cmd_clients:
self.cmd_clients.add(ident) self.cmd_clients.add(ident)
print(f"New CMD Client Connected: {ident} (port: {self.cmd_port})") algo_log(f"New CMD Client Connected: {ident} (port: {self.cmd_port})")
# 解析消息 # 解析消息
try: try:
message = json.loads(message_bytes.decode('utf-8')) message = json.loads(message_bytes.decode('utf-8'))
except json.JSONDecodeError: except json.JSONDecodeError:
print(f"Invalid JSON from CMD client {ident}") algo_log(f"Invalid JSON from CMD client {ident}")
continue return
print(f"Received CMD request: {message}") algo_log(f"Received CMD request: {message}")
method = message.get("method") method = message.get("method")
params = message.get("params") params = message.get("params")
@@ -94,37 +170,40 @@ class zmqServer(threading.Thread):
# 原有命令处理逻辑 # 原有命令处理逻辑
if method == "sync": if method == "sync":
self.state_mode = 'sync' self.state_mode = 'sync'
if method == "targetFreqs": elif method == "targetFreqs":
if not isinstance(params, list): if not isinstance(params, list):
print('targetFreqs must be a list') algo_log(f"targetFreqs must be a list")
continue return
if params != self.targetFreqs: if params != self.targetFreqs:
self.targetFreqs = params self.targetFreqs = params
self.changeTarget = True self.changeTarget = True
if method == "decoderClass": elif method == "decoderClass":
if not isinstance(params, str): if not isinstance(params, str):
print('decoderClass must be a str') algo_log(f"decoderClass must be a str")
continue return
if params != self.decoder_class: if params != self.decoder_class:
self.decoder_class = params self.decoder_class = params
self.decoder_switch = True self.decoder_switch = True
if method == "getReport": elif method == "train":#训练状态
self.getReport = True
if method == "train":#训练状态
self.state_mode = 'train' self.state_mode = 'train'
self.StartTrain = True self.StartTrain = True
self.currentLabel = params # 当前刺激端的训练标签 self.currentLabel = params # 当前刺激端的训练标签
self.sunnyLinker.push_trigger(self.labels[self.currentLabel]) # self.sunnyLinker.push_trigger(self.labels[self.currentLabel])
elif method == "predict":#预测状态 elif method == "predict":#预测状态
self.state_mode = 'predict' self.state_mode = 'predict'
if params == 1: #开始解码 if params == 1: #开始解码
self.StartDecode = True self.StartDecode = True
self.sunnyLinker.push_trigger(0x63) # self.sunnyLinker.push_trigger(0x63)
elif params == 2: #停止解码 elif params == 2: #停止解码
self.IsExitApp = True self.IsExitApp = True
self.running = False self.running = False
elif method == "rest": #休息状态 elif method == "rest": #休息状态
self.state_mode = 'rest' self.state_mode = 'rest'
else:
algo_log(f"未知命令:{method}", level="WARNING")
# elif method == "getReport":
# self.getReport = True
# elif method == "impedance": # elif method == "impedance":
# if params == 1: # if params == 1:
# self.open_Impedance = True # 开启阻抗 # self.open_Impedance = True # 开启阻抗
@@ -138,48 +217,49 @@ class zmqServer(threading.Thread):
处理8100端口原始脑电二进制数据 处理8100端口原始脑电二进制数据
固定格式:上位机发送 (5,66) float32 二维数组字节流(已转换为微伏物理量)→ 转置为 (66,5) 写入双缓冲区 固定格式:上位机发送 (5,66) float32 二维数组字节流(已转换为微伏物理量)→ 转置为 (66,5) 写入双缓冲区
""" """
# 1. 校验ZMQ消息帧完整性 # 1. 校验ZMQ消息帧完整性ROUTER接收DEALER消息的帧格式[客户端ID, 发送方ID, 空帧, 数据帧]
if len(frames) < 3: if len(frames) < 4: # 至少需要4帧
print(f"[ERROR] 无效数据帧长度不足3帧实际长度={len(frames)}") algo_log(f"Invalid data frame: 帧数量不足期望≥4实际{len(frames)}", level="ERROR")
return return
ident, _, data_bytes = frames[:3] # 2. 正确解析帧适配DEALER→ROUTER的帧格式
client_ident, sender_ident, empty_sep, data_bytes = frames[:4]
if empty_sep != b'': # 校验空分隔帧
algo_log(f"Invalid frame separator: 期望空字节,实际{empty_sep}", level="ERROR")
return
# 2. 客户端管理(单客户端场景,自动更新最新身份) # 3. 客户端管理(单客户端场景,自动更新最新身份)
if ident not in self.data_clients: if client_ident not in self.data_clients:
self.data_clients.add(ident) self.data_clients.add(client_ident)
self.current_data_client = ident # 保存唯一客户端身份,用于后续回复滤波结果 self.current_data_client = client_ident # 保存唯一客户端身份,用于后续回复滤波结果
print(f"[INFO] 新数据客户端连接成功:{ident}") print(f"[INFO] 新数据客户端连接成功:{client_ident}")
try: try:
# 3. 精确长度校验(核心:固定(5,66) float32 = 5*66*4=1320字节与int32字节数相同 # 4. 精确长度校验(核心:固定(5,66) float32 = 5*66*4=1320字节
EXPECTED_BYTES = 5 * 66 * 4 # 每个float32占4字节 EXPECTED_BYTES = self.device_info['frame_points'] * self.device_info['channel_nums'] * 4 # 每个float32占4字节
if len(data_bytes) != EXPECTED_BYTES: if len(data_bytes) != EXPECTED_BYTES:
print(f"[ERROR] 数据长度错误:期望{EXPECTED_BYTES}字节,实际{len(data_bytes)}字节") algo_log(f"[ERROR] 数据长度错误:期望{EXPECTED_BYTES}字节,实际{len(data_bytes)}字节", level="ERROR")
return return
# 4. 零拷贝二进制解析 + 维度转换 # 5. 零拷贝二进制解析 + 维度转换
# 步骤:字节流 → (330,) float32数组 → (5,66) 原始格式 → 转置为 (66,5) 缓冲区标准格式
data_np = np.frombuffer(data_bytes, dtype=np.float32) data_np = np.frombuffer(data_bytes, dtype=np.float32)
# 重塑为上位机原始维度 data_np = data_np.reshape(self.device_info['frame_points'], self.device_info['channel_nums'])
data_np = data_np.reshape(5, 66)
# 转置为(通道数, 采样点数)标准格式转换为float64保证滤波运算精度
data_np = data_np.T.astype(np.float64) data_np = data_np.T.astype(np.float64)
# 5. 同时写入双环形缓冲区方法名与现有类保持一致appendBuffer # 6. 写入缓冲区
# 注意:上位机已发送微伏物理量,无需再乘以增益系数
self.paradigmBuffer.appendBuffer(data_np) self.paradigmBuffer.appendBuffer(data_np)
self.filterBuffer.appendBuffer(data_np) self.filterBuffer.appendBuffer(data_np)
# 生产环境必须注释每秒50次打印会导致CPU占用飙升30%以上 algo_log(f"数据写入成功shape={data_np.shape}, 范围=[{data_np.min():.2f}, {data_np.max():.2f}] μV", level="DEBUG")
algo_log(f"数据写入成功shape={data_np.shape}, 范围=[{data_np.min():.2f}, {data_np.max():.2f}] μV", level="DEBUG", record_once=True)
except Exception as e: except Exception as e:
algo_log(f"数据处理失败:{str(e)}", level="ERROR") algo_log(f"数据处理失败:{str(e)}", level="ERROR")
# 调试阶段临时打开,生产环境务必注释 if IniRead('system', 'algo_log_level', 'INFO') == 'DEBUG':
import traceback import traceback
traceback.print_exc() traceback.print_exc()
def _process_send_queue(self): def _process_send_queue(self):
"""处理发送队列,向所有命令客户端广播消息""" """处理发送队列,向所有命令客户端广播消息"""
while not self.send_queue.empty(): while not self.send_queue.empty():
@@ -207,15 +287,15 @@ class zmqServer(threading.Thread):
def run(self): def run(self):
self.running = True self.running = True
print(f"ZMQ Server started - CMD Port: {self.cmd_port}, DATA Port: {self.data_port}") algo_log(f"algo ZMQ Server started - CMD Port: {self.cmd_port}, DATA Port: {self.data_port}", level="INFO")
try: try:
while self.running: while self.running:
# 1. 处理发送队列(命令端口广播) # 1. 处理发送队列(命令端口广播)
self._process_send_queue() self._process_send_queue()
# 2. 轮训监听两个Socket的输入事件10ms超时避免阻塞 # 2. 轮训监听两个Socket的输入事件
socks = dict(self.poller.poll(10)) socks = dict(self.poller.poll(50))
# 处理命令端口消息 # 处理命令端口消息
if self.cmd_socket in socks and socks[self.cmd_socket] == zmq.POLLIN: if self.cmd_socket in socks and socks[self.cmd_socket] == zmq.POLLIN:

View File

@@ -1,445 +0,0 @@
import numpy as np
import zmq
import threading
import json
import queue
import time
from Device.SunnyLinker import SunnyLinker64, RingBuffer
from collections import deque
class zmqServer(threading.Thread):
def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100):
threading.Thread.__init__(self)
self.host = host
self.cmd_port = cmd_port
self.data_port = data_port
self.running = False
self.get_Impedance = False
self.open_Impedance = None
self.StartDecode = False
self.StartTrain = False
self.state_mode = None
self.currentLabel = -1
self.IsExitApp = False
self.getReport = False
self.daemon = True
# ZMQ Context
self.context = zmq.Context()
# 指令通道 (8099) - ROUTER
self.cmd_socket = self.context.socket(zmq.ROUTER)
self.cmd_socket.setsockopt(zmq.RCVHWM, 1000)
self.cmd_socket.setsockopt(zmq.SNDHWM, 1000)
self.cmd_socket.bind(f"tcp://{self.host}:{cmd_port}")
# 数据通道 (8100)) - ROUTER
self.data_socket = self.context.socket(zmq.ROUTER)
self.data_socket.setsockopt(zmq.RCVHWM, 1000)
self.data_socket.setsockopt(zmq.RCVTIMEO, 50)
self.data_socket.bind(f"tcp://{self.host}:{data_port}")
self.targetFreqs = []
self.changeTarget = False
self.sunnyLinker = SunnyLinker64(None, None, None, None, None)
self.labels = [0x01, 0x02, 0x03]
self.decoder_switch = False
self.decoder_class = None
self.cmd_clients = set()
self.data_clients = set()
self.send_queue = queue.Queue()
# ========== 数据缓冲区 (RingBuffer) ==========
# 与 SunnyLinker 保持一致,使用 RingBuffer
# 66 = 64 EEG通道 + 1 事件通道(第65) + 1 标签序号通道(第66)
# 缓存约 10 秒数据 (250Hz * 10s = 2500 点)
self.n_chan = 66
self.t_buffer = 10.0 # 缓冲区时长(秒)
self.__ringBuffer = RingBuffer(self.n_chan, int(self.t_buffer * 250))
# 事件检测相关
self._event_lock = threading.Lock()
self._epoch_finished = False
self._event_inner_idx = -1
self.pack_contain_event = False
self.predict_event = 99
self.events = [1, 2, self.predict_event]
self.count_events = {}
self.latency = 50
self.train_latency = 50
# 当前事件标签序号 (从第66通道获取)
self.current_label_index = 0
# 初始化标志
self._interval_inited = False
self._currentLabel = -1
# 注册的客户端(兼容旧接口)
self.clients = set()
# ========== 事件属性:线程安全访问 ==========
@property
def epoch_finished(self):
with self._event_lock:
return self._epoch_finished
@epoch_finished.setter
def epoch_finished(self, value):
with self._event_lock:
self._epoch_finished = value
@property
def event_inner_idx(self):
with self._event_lock:
return self._event_inner_idx
@event_inner_idx.setter
def event_inner_idx(self, value):
with self._event_lock:
self._event_inner_idx = value
@property
def interval_inited(self):
return self._interval_inited
@interval_inited.setter
def interval_inited(self, value):
self._interval_inited = value
@property
def currentLabel(self):
return self._currentLabel
@currentLabel.setter
def currentLabel(self, value):
self._currentLabel = value
def broadcast_message(self, method, params):
"""Put message into queue to be sent to all connected clients"""
self.send_queue.put((method, params))
# ========== 数据缓冲区操作接口 ==========
def GetDataLenCount(self):
"""返回缓冲区当前数据点数"""
return self.__ringBuffer.nUpdate
def getData(self, count):
"""获取最新count个数据点不消费只读"""
with self.__ringBuffer.RingBufferLock:
count = min(count, self.__ringBuffer.nUpdate)
if count == 0:
return np.zeros((self.n_chan, 0))
# 计算读取范围(从尾部取最新数据)
read_end = (self.__ringBuffer.currentPtr - 1) % self.__ringBuffer.n_points
read_start = (read_end - count + 1) % self.__ringBuffer.n_points
if self.__ringBuffer.currentPtr == 0:
read_start = self.__ringBuffer.n_points - count
read_end = self.__ringBuffer.n_points - 1
if read_start <= read_end:
data = self.__ringBuffer.buffer[:, read_start:read_end + 1]
else:
part1 = self.__ringBuffer.buffer[:, read_start:]
part2 = self.__ringBuffer.buffer[:, :read_end + 1]
data = np.concatenate((part1, part2), axis=1)
return data
def consumeData(self, count):
"""消费(丢弃)指定数量的数据点,从头部移除"""
with self.__ringBuffer.RingBufferLock:
count = min(count, self.__ringBuffer.nUpdate)
self.__ringBuffer.readPtr = (self.__ringBuffer.readPtr + count) % self.__ringBuffer.n_points
self.__ringBuffer.nUpdate -= count
def ResetAll(self):
"""重置缓冲区"""
with self.__ringBuffer.RingBufferLock:
self.__ringBuffer.resetAllPara()
with self._event_lock:
self._epoch_finished = False
self._event_inner_idx = -1
self.pack_contain_event = False
self.count_events.clear()
self.current_label_index = 0
def reset_data_buffer(self):
self.ResetAll()
def reset_state(self):
self.ResetAll()
def interval_init(self, decoder_class):
"""初始化事件检测参数"""
import ast
from PubLibrary.InifileHelper import IniRead
if decoder_class == 'ssmvep':
interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
self.interval_epoch = [int(i * 250) for i in interval_epoch]
self.train_epoch = [int(self.interval_epoch[0]),
int(self.interval_epoch[1] + 0.1 * 250)]
self.latency = (self.interval_epoch[1] + 0.1 * 250) // 5
self.train_latency = (self.train_epoch[1] + 0.1 * 250) // 5
elif decoder_class == 'mi':
interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
self.interval_epoch = [int(i * 250) for i in interval_epoch]
self.train_epoch = self.interval_epoch.copy()
self.latency = self.interval_epoch[1] // 5
self.train_latency = self.latency
self.count_events = {}
self._event_inner_idx = -1
self._epoch_finished = False
self.pack_contain_event = False
self.predict_event = 99
self.events = [1, 2, self.predict_event]
self._interval_inited = True
# ========== 事件检测 ==========
def detect_event(self, data_matrix):
"""
检测事件通道中的触发信号
@param data_matrix: shape (66, N) - N个采样点的数据
第65行(索引64) = 事件通道
第66行(索引65) = 标签通道
@return: 是否检测到事件
"""
if data_matrix.shape[1] == 0:
return False
self.pack_contain_event = False
event_channel = data_matrix[64, :] # 第65通道 = 标签值(event值)
label_channel = data_matrix[65, :] # 第66通道 = 标签序号(label index)
events = event_channel.tolist()
with self._event_lock:
self._event_inner_idx = -1
self.current_event_label = 0
for idx, event in enumerate(events):
if int(event) in self.events:
self._event_inner_idx = idx
self.current_label_index = int(label_channel[idx])
self.pack_contain_event = True
new_key = f"{event}_{time.time()}"
latency = self.latency if event == self.predict_event else self.train_latency
self.count_events[new_key] = latency + 1
# 延迟计数递减
drop_items = []
for key, value in self.count_events.items():
value = value - 1
if value == 0:
drop_items.append(key)
self.count_events[key] = value
for key in drop_items:
del self.count_events[key]
if drop_items:
self._epoch_finished = True
# 检测到事件时清除RingBuffer中之前的数据只保留当前包
if self.pack_contain_event:
self.__ringBuffer.resetAllPara()
return True
self._epoch_finished = False
return False
def run(self):
self.running = True
print(f"Server running - CMD: {self.cmd_port}, DATA: {self.data_port}")
cmd_poller = zmq.Poller()
cmd_poller.register(self.cmd_socket, zmq.POLLIN)
data_poller = zmq.Poller()
data_poller.register(self.data_socket, zmq.POLLIN)
try:
while self.running:
# --- 处理发送队列 (指令通道) ---
while not self.send_queue.empty():
method, params = self.send_queue.get()
if self.cmd_clients:
try:
msg = {'method': method, 'params': params}
msg_bytes = json.dumps(msg).encode('utf-8')
for client_id in list(self.cmd_clients):
try:
self.cmd_socket.send_multipart([client_id, b'', msg_bytes])
except Exception:
pass
except Exception:
pass
# --- 处理指令通道 ---
socks = dict(cmd_poller.poll(10))
if self.cmd_socket in socks:
self._handle_cmd_socket()
# --- 处理数据通道 ---
socks = dict(data_poller.poll(10))
if self.data_socket in socks:
self._handle_data_socket()
except Exception as e:
print(f"Server error: {e}")
finally:
self.running = False
self.cmd_socket.close()
self.data_socket.close()
self.context.term()
def _handle_cmd_socket(self):
"""处理指令通道消息"""
try:
frames = self.cmd_socket.recv_multipart()
if len(frames) < 3:
return
ident, _, message_bytes = frames[:3]
self.cmd_clients.add(ident)
self.clients.add(ident)
message = json.loads(message_bytes.decode('utf-8'))
method = message.get("method")
params = message.get("params")
print(f"[CMD] {method}: {params}")
if method == "sync":
self.state_mode = 'sync'
elif method == "targetFreqs":
if isinstance(params, list) and params != self.targetFreqs:
self.targetFreqs = params
self.changeTarget = True
elif method == "decoderClass":
if isinstance(params, str) and params != self.decoder_class:
self.decoder_class = params
self.decoder_switch = True
elif method == "getReport":
self.getReport = True
elif method == "train":
self.state_mode = 'train'
self.StartTrain = True
self.currentLabel = params
elif method == "predict":
self.state_mode = 'predict'
if params == 1:
self.StartDecode = True
elif params == 2:
self.IsExitApp = True
self.running = False
elif method == "rest":
self.state_mode = 'rest'
elif method == "impedance":
if params == 1:
self.open_Impedance = True
self.get_Impedance = True
elif params == 2:
self.open_Impedance = False
self.get_Impedance = False
except Exception as e:
print(f"CMD socket error: {e}")
def _handle_data_socket(self):
"""处理数据通道消息 (EEG数据)
上位机数据格式:
- 数据帧: [identity, '', meta_json, data_buffer]
data_buffer = [N, 66] float32 -> 转置为 [66, N]
"""
try:
frames = self.data_socket.recv_multipart()
if len(frames) < 4:
return
ident, _, message_bytes = frames[:3]
self.data_clients.add(ident)
meta = json.loads(message_bytes.decode('utf-8'))
# data: [N, 66] -> 转置 -> [66, N]
raw_data = np.frombuffer(frames[3], dtype=np.float32)
n_samples, n_channels = meta.get('shape', [5, 66])
data_matrix = raw_data.reshape(n_samples, n_channels).T.astype(np.float32)
# 写入 RingBuffer
with self.__ringBuffer.RingBufferLock:
self.__ringBuffer.appendBuffer(data_matrix)
# 事件检测
self.detect_event(data_matrix)
except Exception as e:
print(f"DATA socket error: {e}")
# ========== 各范式数据访问接口 ==========
def get_MIData(self):
"""获取MI导联数据 (21通道 + 事件)"""
data = self.getData(self.GetDataLenCount())
rows_to_extract = [8, 15, 12, 14, 18, 23, 16, 59, 50, 58, 17, 45, 29, 11, 10, 19, 20, 61, 51, 60, 21, 64, 65]
row_to_select = np.array(rows_to_extract)
if data.shape[1] > 0:
return data[row_to_select, :]
return np.zeros((len(rows_to_extract), 0))
def get_SSMVEPData(self):
"""获取SSMVEP导联数据 (8通道 + 事件)"""
data = self.getData(self.GetDataLenCount())
rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64, 65]
row_to_select = np.array(rows_to_extract)
if data.shape[1] > 0:
return data[row_to_select, :]
return np.zeros((len(rows_to_extract), 0))
def getDataViaSSVEP(self, count):
"""获取SSVEP数据 (8通道 + 事件)"""
data = self.getData(count)
rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64]
row_to_select = np.array(rows_to_extract)
if data.shape[1] > 0:
return data[row_to_select, :]
return np.zeros((len(rows_to_extract), 0))
def get_concentrateData(self, count):
"""获取专注力数据 (2通道)"""
data = self.getData(count)
rows_to_extract = [0, 1]
row_to_select = np.array(rows_to_extract)
if data.shape[1] > 0:
return data[row_to_select, :]
return np.zeros((len(rows_to_extract), 0))
def get_blinkData(self, count):
"""获取眨眼数据 (2通道)"""
data = self.getData(count)
rows_to_extract = [0, 1]
row_to_select = np.array(rows_to_extract)
if data.shape[1] > 0:
return data[row_to_select, :]
return np.zeros((len(rows_to_extract), 0))
def getImpedance(self, data, decoder_class):
"""计算阻抗ZMQ模式下不可用"""
return np.zeros(8)
def stop(self):
self.running = False
self.cmd_socket.close()
self.data_socket.close()
self.context.term()
if __name__ == '__main__':
server = zmqServer()
server.start()

View File

@@ -8,6 +8,7 @@ import os
# import logging # import logging
import base64 import base64
import io import io
import math
# logger = logging.getLogger(__name__) # logger = logging.getLogger(__name__)
# #
@@ -22,7 +23,7 @@ import io
class Calculate(): class Calculate():
def __init__(self, Threshold_value_low, Threshold_value_high, fs=250, win_len=10): def __init__(self, Threshold_value_low, Threshold_value_high, fs=250, win_len=10, config=None):
self.Threshold_value_low = Threshold_value_low self.Threshold_value_low = Threshold_value_low
self.Threshold_value_high = Threshold_value_high self.Threshold_value_high = Threshold_value_high
self.fs = fs self.fs = fs
@@ -31,47 +32,73 @@ class Calculate():
self.EVI_result = [] self.EVI_result = []
self.eegQueue = deque(maxlen=win_len) self.eegQueue = deque(maxlen=win_len)
# # 存储历史数据用于绘图
# self.beta_history = []
# self.alpha_history = []
# self.theta_history = []
# self.focus_history = []
# self.timestamp_history = []
#
# # 记录开始时间
# self.start_time = None
# self.recording = False
#
# # 图表保存路径
# self.chart_dir = "reports"
# if not os.path.exists(self.chart_dir):
# os.makedirs(self.chart_dir)
# print(f"[调试] 创建目录: {self.chart_dir}")
# 初始化滤波器 # 初始化滤波器
self.b_notch, self.a_notch = signal.iirnotch(50 / (self.fs/2), 30) self.b_notch, self.a_notch = signal.iirnotch(50 / (self.fs/2), 30)
self.b_design = signal.firwin(65, [2 / (self.fs/2), 40 / (self.fs/2)], pass_zero=False) self.b_design = signal.firwin(65, [2 / (self.fs/2), 40 / (self.fs/2)], pass_zero=False)
self.last_focus = None
# 异步滤波系数配置(核心手感控制纽)
self.alpha_up = 1 # 上升系数:较小,保证分数平滑爬升,过滤偶发的瞬时高能量
# alpha_down / shrink_factor 从 config.ini 读取,方便上位机调参
if config:
self.alpha_down = float(config.get('alpha_down', 0.8))
self.shrink_factor = float(config.get('shrink_factor', 0.5))
else:
self.alpha_down = 0.8
self.shrink_factor = 0.5
print("[调试] Calculate 类初始化完成") print("[调试] Calculate 类初始化完成")
def calculate_focus(self, beta, alpha, theta): def calculate_focus(self, beta, alpha, theta):
""" """
专注度计算 - 固定映射版本 专注度计算 - 三区间门限异步滤波版本
""" """
# 0. 频带特征预处理
theta_mod = theta ** 0.7
# 原始比值 # 原始比值
raw = beta / (alpha + theta + 1e-10) raw = beta / (alpha + theta_mod + 1e-10)
# Sigmoid 映射:让 raw 在 0.3-1.5 区间敏感 exponent = 2.0
# 参数可调:
# k = 12 (斜率,越大越陡)
# x0 = 0.6 (中心点raw=0.6时focus≈50)
k = 12.0
x0 = 0.6
focus = 100.0 / (1.0 + np.exp(-k * (raw - x0)))
# 可选:添加滑动平均平滑 # 1. 防止脑电比值出现负数异常值
raw_input = max(raw, 0.0)
# 2. 2次幂纵轴压缩映射 (shrink_factor 从 config.ini 读取)
focus_raw = 100 * self.shrink_factor * (raw_input ** exponent)
# 3. 计算当前帧的瞬时分数 (基准量级 0-120)
instant_focus = 120 * (1.0 - np.exp(-focus_raw / 100.0))
# 4. 核心修改:三区间门限时域滤波
if self.last_focus is None:
# 冷启动:首帧直接赋值
focus = instant_focus
else:
# 判断当前瞬时分数是否处于【极端区】(80以上 或 60以下)
if instant_focus > 85.0 or instant_focus < 60.0:
# 执行异步低通时域滤波
if instant_focus >= self.last_focus:
# 趋势上升:慢爬升
focus = self.alpha_up * instant_focus + (1 - self.alpha_up) * self.last_focus
else:
# 趋势下降:快跌落
focus = self.alpha_down * instant_focus + (1 - self.alpha_down) * self.last_focus
else:
# 【高灵敏自由区】(60 <= instant_focus <= 80)
# 不执行异步滤波,分数直接跟随瞬时值,保证中间状态绝对跟手
focus = instant_focus
# 5. 更新历史状态缓存
self.last_focus = focus
# 打印在线调试日志,方便观察区间切换
zone_tag = "极端区(滤波)" if (instant_focus > 80 or instant_focus < 60) else "自由区(直通)"
print(f"原始特征比值 raw: {raw:.4f} | 瞬时分数: {instant_focus:.1f} | 滤波后分数: {focus:.1f}")
# 最终返回整型
return int(focus) return int(focus)
def calculate_all(self, data, fs, nperseg=1000): def calculate_all(self, data, fs, nperseg=1000):
mean_x = np.mean(data, axis=-1, keepdims=True) mean_x = np.mean(data, axis=-1, keepdims=True)
data = data - mean_x data = data - mean_x
@@ -319,14 +346,16 @@ class Calculate():
if eegData.size == 0: if eegData.size == 0:
return None return None
eegData -= np.mean(eegData, axis=-1, keepdims=True) eegData -= np.mean(eegData, axis=-1, keepdims=True)
eegData = signal.lfilter(self.b_notch, self.a_notch, eegData) # eegData = signal.lfilter(self.b_notch, self.a_notch, eegData) # 陷波
eegData = signal.lfilter(self.b_design, 1, eegData) # eegData = signal.lfilter(self.b_design, 1, eegData) # 滤波
focus_score, CLI_score, beta, alpha, theta = self.calculate_all(eegData, fs=self.fs, nperseg=1000) focus_score, CLI_score, beta_psd, alpha_psd, theta_psd = self.calculate_all(eegData, fs=self.fs, nperseg=1000)
# self.add_data_point(focus_score, beta, alpha, theta) # self.add_data_point(focus_score, beta_psd, alpha_psd, theta_psd) # 已注释(方法已移除)
# return (focus_score)
return (focus_score, beta_psd)
# return None
return focus_score
return None
class Calculate2(): class Calculate2():

View File

@@ -13,9 +13,6 @@ Num_blocks = 1
Num_trials = 10 Num_trials = 10
Audio_device = 0 Audio_device = 0
Rest_time = 2 Rest_time = 2
Device_type = 1
Device_Host = 127.0.0.1
Device_Port = 5086
Upper_Host = 127.0.0.1 Upper_Host = 127.0.0.1
Upper_Port = 8088 Upper_Port = 8088
Serial_port = COM44 Serial_port = COM44
@@ -24,148 +21,8 @@ console_output = 1
; 64 导设备配置 ; 64 导设备配置
[device_type_1] [device_type_1]
device_sample_rate = 250 sample_rate = 250
device_channel_nums = 66 frame_points = 5
device_channel_names = ['FP1', 'FP2', 'FC1', 'FC2', 'CP1', 'CP2', 'F3', 'F4', 'P3', 'P4', 'O1', 'O2', 'FT9', 'FT10', 'F7', 'F8', 'TP9', 'TP10', 'AF4', 'PO8', 'PZ', 'FCZ'] channel_nums = 66
device_channel_index = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18] channel_names = ['FP1', 'FP2', 'PO6', 'POZ', 'F3', 'F4', 'FPZ', 'AF4', 'FC3', 'PO8', 'CP2', 'CP1', 'FCZ', 'PO5', 'FC2', 'FC1', 'C3', 'C4', 'FC4', 'CP4', 'P3', 'P4', 'F5', 'C5', 'F6', 'PO4', 'CP6', 'CP5', 'PO3', 'CP3', 'FC6', 'FC5', 'CB1', 'CB2', 'P5', 'AF7', 'A1', 'T7', 'FT7', 'TP7', 'FT8', 'AF8', 'F8', 'F7', 'P6', 'C6', 'O2', 'O1', 'T8', 'P7', 'CZ', 'PZ', 'P8', 'FZ', 'OZ', 'PO7', 'TP8', 'AF3', 'C2', 'C1', 'P2', 'P1', 'F2', 'F1', 'label', 'label_tag']
channel_index = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65]
[Layout]
main_splitter_left = 993
main_splitter_right = 922
right_splitter_left = 233
right_splitter_right = 771
left_splitter_left = 503
left_splitter_right = 501q
[channel]
channel_x_fp1 = 419
channel_y_fp1 = 124
channel_x_fc1 = 439
channel_y_fc1 = 296
channel_x_fp2 = 576
channel_y_fp2 = 124
channel_x_fc2 = 556
channel_y_fc2 = 299
channel_x_f3 = 397
channel_y_f3 = 231
channel_x_cp1 = 439
channel_y_cp1 = 426
channel_x_f4 = 601
channel_y_f4 = 232
channel_x_cp2 = 559
channel_y_cp2 = 425
channel_x_fc3 = 379
channel_y_fc3 = 295
channel_x_af4 = 571
channel_y_af4 = 171
channel_x_po8 = 645
channel_y_po8 = 564
channel_x_fpz = 499
channel_y_fpz = 112
channel_x_fcz = 499
channel_y_fcz = 300
channel_x_poz = 500
channel_y_poz = 554
channel_x_po5 = 387
channel_y_po5 = 551
channel_x_po6 = 611
channel_y_po6 = 551
channel_x_c3 = 373
channel_y_c3 = 363
channel_x_fc5 = 319
channel_y_fc5 = 292
channel_x_c4 = 620
channel_y_c4 = 363
channel_x_fc6 = 676
channel_y_fc6 = 288
channel_x_p3 = 398
channel_y_p3 = 491
channel_x_cp5 = 322
channel_y_cp5 = 430
channel_x_p4 = 600
channel_y_p4 = 489
channel_x_cp6 = 678
channel_y_cp6 = 430
channel_x_c5 = 313
channel_y_c5 = 361
channel_x_f6 = 650
channel_y_f6 = 223
channel_x_f5 = 349
channel_y_f5 = 224
channel_x_po4 = 573
channel_y_po4 = 551
channel_x_po3 = 429
channel_y_po3 = 550
channel_x_cp4 = 619
channel_y_cp4 = 424
channel_x_cp3 = 381
channel_y_cp3 = 426
channel_x_fc4 = 619
channel_y_fc4 = 295
channel_x_o1 = 423
channel_y_o1 = 598
channel_x_ft9 = 252
channel_y_ft9 = 168
channel_x_o2 = 576
channel_y_o2 = 597
channel_x_ft10 = 798
channel_y_ft10 = 277
channel_x_f7 = 295
channel_y_f7 = 214
channel_x_tp9 = 202
channel_y_tp9 = 445
channel_x_f8 = 701
channel_y_f8 = 215
channel_x_t7 = 252
channel_y_t7 = 362
channel_x_tp7 = 261
channel_y_tp7 = 436
channel_x_ft8 = 734
channel_y_ft8 = 283
channel_x_ft7 = 264
channel_y_ft7 = 286
channel_x_af8 = 645
channel_y_af8 = 159
channel_x_af7 = 351
channel_y_af7 = 160
channel_x_p6 = 652
channel_y_p6 = 499
channel_x_p5 = 348
channel_y_p5 = 499
channel_x_c6 = 683
channel_y_c6 = 362
channel_x_f1 = 447
channel_y_f1 = 236
channel_x_t8 = 745
channel_y_t8 = 361
channel_x_f2 = 549
channel_y_f2 = 235
channel_x_p7 = 300
channel_y_p7 = 505
channel_x_c1 = 435
channel_y_c1 = 363
channel_x_p8 = 698
channel_y_p8 = 508
channel_x_c2 = 559
channel_y_c2 = 359
channel_x_fz = 499
channel_y_fz = 238
channel_x_po7 = 354
channel_y_po7 = 562
channel_x_tp8 = 735
channel_y_tp8 = 438
channel_x_oz = 498
channel_y_oz = 609
channel_x_af3 = 428
channel_y_af3 = 170
channel_x_pz = 501
channel_y_pz = 486
channel_x_p2 = 551
channel_y_p2 = 483
channel_x_cz = 499
channel_y_cz = 361
channel_x_p1 = 448
channel_y_p1 = 488

138
datamock.py Normal file
View File

@@ -0,0 +1,138 @@
import zmq
import numpy as np
import time
from datetime import datetime
# ========== 参数配置 ==========
FS = 250 # 采样率 Hz
N_SAMPLES_PER_PKT = 5 # 每包采样点数
N_CHAN = 66 # 通道数: 64 EEG + 1 标签值 + 1 标签序号
EEG_FREQ = 10 # EEG 正弦波频率 Hz
EEG_AMP = 100.0 # EEG 幅值 100μV
LABEL_INTERVAL = 5 # 标签间隔秒数
SERVER_ADDR = 'tcp://127.0.0.1:8100'
# 发送间隔: 每包 5 采样点 / 250Hz = 20ms
PKT_INTERVAL = N_SAMPLES_PER_PKT / FS
def build_packet(global_sample_idx):
"""
生成一包 [5, 66] 的 float32 数据
:param global_sample_idx: 当前包第一个采样点在全局序列中的索引 (从 0 开始)
:return: np.ndarray shape [5, 66]
"""
# 当前包内 5 个采样点对应的时间(秒)
t = (global_sample_idx + np.arange(N_SAMPLES_PER_PKT)) / FS
# Ch0-63: EEG 10Hz 正弦波,幅值 100μV
# t shape [5,]sin 乘以标量后仍是 [5,],需要 reshape 为 [5,1] 再广播到 64 通道
eeg = (EEG_AMP * np.sin(2 * np.pi * EEG_FREQ * t)).reshape(N_SAMPLES_PER_PKT, 1) # [5, 1]
eeg = np.tile(eeg, (1, 64)) # [5, 64]
# Ch64: 标签值通道,初始化为 0
event = np.zeros((N_SAMPLES_PER_PKT, 1), dtype=np.float32)
# Ch65: 标签序号通道,初始化为 0
label_idx = np.zeros((N_SAMPLES_PER_PKT, 1), dtype=np.float32)
# 拼成 [5, 66]
packet = np.concatenate([eeg, event, label_idx], axis=1).astype(np.float32)
return packet
def should_send_label(global_sample_idx):
"""
判断当前包是否包含标签触发点(每 5s 的最后一个采样点)
采样点索引从 0 开始,每 5s = 1250 个采样点
最后一个采样点索引: 1249, 2499, 3749, ...
由于每包 5 个采样点,标签点落在包内的最后一个采样点位置
即当前包起始索引 global_sample_idx 必须使得:
global_sample_idx <= 标签点索引 < global_sample_idx + N_SAMPLES_PER_PKT
也就是 global_sample_idx <= 1249 < global_sample_idx + 5
即 global_sample_idx = 1245, 2495, 3745, ...
即 global_sample_idx = n * LABEL_INTERVAL * FS - N_SAMPLES_PER_PKT
"""
samples_per_interval = LABEL_INTERVAL * FS
# 检查当前包是否包含 interval 的最后一个采样点
# 标签点索引 = n * 1250 - 1当 global_sample_idx = n*1250-5 时,标签在包内索引 4
return (global_sample_idx + N_SAMPLES_PER_PKT - 1) % samples_per_interval == samples_per_interval - 1
def main():
ctx = zmq.Context()
sock = ctx.socket(zmq.DEALER)
sock.connect(SERVER_ADDR)
print(f"[{datetime.now().strftime('%H:%M:%S')}] ZMQ Dealer 连接到 {SERVER_ADDR}")
global_sample_idx = 0 # 全局采样点计数器
label_type = 1 # 当前标签类型: 1 或 2
label1_count = 0 # label=1 的序号计数器
label2_count = 0 # label=2 的序号计数器
packet_count = 0 # 已发送包数
print(f"[{datetime.now().strftime('%H:%M:%S')}] 开始发送模拟数据 ...")
print(f" 采样率: {FS}Hz | 每包 {N_SAMPLES_PER_PKT} 采样点 | 发送间隔 {PKT_INTERVAL*1000:.0f}ms")
print(f" EEG: {EEG_FREQ}Hz 正弦波 | 幅值 {EEG_AMP}μV")
print(f" 标签: 每 {LABEL_INTERVAL}s 末尾采样点触发 | label 1/2 交替")
print("-" * 50)
try:
while True:
t_start = time.perf_counter()
# 构建当前包
packet = build_packet(global_sample_idx)
# 检查是否需要放置标签
if should_send_label(global_sample_idx):
if label_type == 1:
label1_count += 1
label_value = 1
label_number = label1_count
else:
label2_count += 1
label_value = 2
label_number = label2_count
# 标签放在当前包最后一个采样点(索引 4
packet[4, 64] = label_value
packet[4, 65] = label_number
ts = datetime.now().strftime('%H:%M:%S')
print(f"[{ts}] 标签触发: label={label_value}, 序号={label_number} "
f"(global_sample_idx={global_sample_idx})")
# 交替标签类型
label_type = 2 if label_type == 1 else 1
# 发送: multipart 3帧 [identity, '', data]
sock.send_multipart([
b'datamock',
b'',
packet.tobytes()
])
# 每 50 包打印一次进度
if packet_count % 50 == 0:
ts = datetime.now().strftime('%H:%M:%S')
print(f"[{ts}] 已发送 {packet_count} 包 (global_sample_idx={global_sample_idx})")
global_sample_idx += N_SAMPLES_PER_PKT
packet_count += 1
# 精确控制发送节奏: 等待到 PKT_INTERVAL 秒
elapsed = time.perf_counter() - t_start
sleep_time = PKT_INTERVAL - elapsed
if sleep_time > 0:
time.sleep(sleep_time)
except KeyboardInterrupt:
print(f"\n[{datetime.now().strftime('%H:%M:%S')}] 停止发送,共发送 {packet_count}")
finally:
sock.close()
ctx.term()
if __name__ == '__main__':
main()

View File

@@ -1,51 +1,44 @@
# log.py
import os import os
from datetime import datetime from datetime import datetime
import logging import logging
from logging.handlers import RotatingFileHandler from logging.handlers import RotatingFileHandler
import inspect # 新增导入
from PubLibrary.InifileHelper import IniRead from PubLibrary.InifileHelper import IniRead
console_output = IniRead('system', 'console_output', '1') console_output = IniRead('system', 'console_output', '1')
log_level = IniRead('system', 'algo_log_level', 'INFO') log_level = IniRead('system', 'algo_log_level', 'INFO')
# 新增日志去重缓存key为日志内容value为是否已打印
log_once_cache = set() log_once_cache = set()
# 缓存已经创建过的logger避免重复创建handler
logger_cache = {}
def init_module_logger(): def init_module_logger(logger_name):
""" log_dir = './logs/'
初始化指定模块的日志器
:return: 对应模块的logger实例
"""
# 缓存命中则直接返回
log_dir = './logs/' # 确保日志目录存在
os.makedirs(log_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f'algo_log_{datetime.now().strftime("%Y-%m-%d")}.log') log_file = os.path.join(log_dir, f'algo_log_{datetime.now().strftime("%Y-%m-%d")}.log')
# 初始化logger # 已创建直接返回
logger = logging.getLogger('decoderLogger') if logger_name in logger_cache:
logger.setLevel(log_level) return logger_cache[logger_name]
logger = logging.getLogger(logger_name)
logger.setLevel(log_level)
if logger.handlers: if logger.handlers:
logger_cache[logger_name] = logger
return logger return logger
# 设置日志轮转最大10个文件每个10MB
file_handler = RotatingFileHandler( file_handler = RotatingFileHandler(
log_file, log_file,
maxBytes=10*1024*1024, maxBytes=10*1024*1024,
backupCount=10, backupCount=10,
encoding='utf-8' encoding='utf-8'
) )
# 日志格式
formatter = logging.Formatter( formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s', '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S' datefmt='%Y-%m-%d %H:%M:%S'
) )
file_handler.setFormatter(formatter) file_handler.setFormatter(formatter)
logger.setLevel(log_level)
logger.addHandler(file_handler) logger.addHandler(file_handler)
if console_output: if console_output:
@@ -53,27 +46,26 @@ def init_module_logger():
console_handler.setFormatter(formatter) console_handler.setFormatter(formatter)
logger.addHandler(console_handler) logger.addHandler(console_handler)
logger_cache[logger_name] = logger
return logger return logger
def algo_log(content, level="INFO", record_once=False):
"""
通用日志函数,支持按模块输出到不同日志文件
:param content: 日志内容
:param level: 日志级别DEBUG/INFO/WARNING/ERROR/FATAL
:param record_once: 是否只打印一次该日志内容默认False
"""
# 初始化模块日志器
logger = init_module_logger()
# 新增:处理只打印一次的逻辑 def algo_log(content, level="INFO", record_once=False):
# 向上回溯1层栈拿到调用algo_log的代码文件信息
frame = inspect.currentframe().f_back
file_path = frame.f_code.co_filename
# 提取py文件名不带后缀/带后缀自选)
file_name = os.path.basename(file_path) # 例zmqServer.py
# file_name = os.path.splitext(os.path.basename(file_path))[0] # 例zmqServer
logger = init_module_logger(file_name)
if record_once: if record_once:
# 生成唯一标识可根据需要调整比如拼接level增强唯一性
log_key = f"{level.upper()}_{content}" log_key = f"{level.upper()}_{content}"
if log_key in log_once_cache: if log_key in log_once_cache:
return # 已打印过,直接返回 return
log_once_cache.add(log_key) # 未打印过,加入缓存 log_once_cache.add(log_key)
# 根据级别输出日志
level_upper = level.upper() level_upper = level.upper()
if level_upper == "DEBUG": if level_upper == "DEBUG":
logger.debug(content) logger.debug(content)
@@ -83,5 +75,5 @@ def algo_log(content, level="INFO", record_once=False):
logger.error(content) logger.error(content)
elif level_upper == "FATAL": elif level_upper == "FATAL":
logger.fatal(content) logger.fatal(content)
else: # 默认INFO级别 else:
logger.info(content) logger.info(content)

View File

@@ -6,32 +6,33 @@ import time
from Decoder import Decoder_main from Decoder import Decoder_main
from PubLibrary.RunOnce import is_program_running from PubLibrary.RunOnce import is_program_running
from PubLibrary.InifileHelper import IniRead from PubLibrary.InifileHelper import IniRead
from logs.log import algo_log
def get_device_info(device_type): def get_device_info(device_type):
section = f'device_type_{device_type}' section = f'device_type_{device_type}'
device_info = { device_info = {
'device_sample_rate': int(IniRead(section, 'sample_rate')) if IniRead(section, 'sample_rate') is not None else 250, 'sample_rate': int(IniRead(section, 'sample_rate')) if IniRead(section, 'sample_rate') is not None else 250,
'frame_points': int(IniRead(section, 'frame_points')) if IniRead(section, 'frame_points') is not None else 5,
'' 'channel_nums': int(IniRead(section, 'channel_nums')) if IniRead(section, 'channel_nums') is not None else 66,
'channel_names': IniRead(section, 'channel_names') if IniRead(section, 'channel_names') is not None else None,
'channel_index': IniRead(section, 'channel_index') if IniRead(section, 'channel_index') is not None else None,
} }
return device_info
if __name__ == "__main__": if __name__ == "__main__":
if not is_program_running(): if not is_program_running():
# 解析命令行参数 # 解析命令行参数
parser = argparse.ArgumentParser(description="EEG Decoder Application") # parser = argparse.ArgumentParser(description="EEG Decoder Application")
parser.add_argument('-dt', '-t','--device-type', type=int, default=None, help="Device Type") # parser.add_argument('-dt', '-t','--device-type', type=int, default=None, help="Device Type")
# parser.add_argument('-dh', '--device-host', type=str, default=None, help="Device Host IP") # parser.add_argument('-dh', '--device-host', type=str, default=None, help="Device Host IP")
# parser.add_argument('-dp', '--device-port', type=int, default=None, help="Device Port") # parser.add_argument('-dp', '--device-port', type=int, default=None, help="Device Port")
# parser.add_argument('-uh', '--upper-host', type=str, default=None, help="Upper Computer Host IP") # parser.add_argument('-uh', '--upper-host', type=str, default=None, help="Upper Computer Host IP")
# parser.add_argument('-up', '--upper-port', type=int, default=None, help="Upper Computer Port") # parser.add_argument('-up', '--upper-port', type=int, default=None, help="Upper Computer Port")
# args = parser.parse_args()
args = parser.parse_args()
device_info= get_device_info(args.device_type)
decoder = Decoder_main(device_info=device_info)
# decoder.connect( # decoder.connect(
# device_type=args.device_type, # device_type=args.device_type,
# device_host=args.device_host, # device_host=args.device_host,
@@ -40,6 +41,10 @@ if __name__ == "__main__":
# upper_port=args.upper_port # upper_port=args.upper_port
# ) # )
device_info= get_device_info(1)
algo_log(f"device_info: {device_info}", level="INFO")
decoder = Decoder_main(device_info=device_info)
try: try:
decoder.start() decoder.start()
while not decoder.zmqServer.IsExitApp: while not decoder.zmqServer.IsExitApp: