Compare commits
2 Commits
67587f354b
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
34c9115258 | ||
|
|
69b2802895 |
408
Decoder.py
408
Decoder.py
@@ -1,7 +1,6 @@
|
|||||||
import ast
|
import ast
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import threading
|
import threading
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
@@ -11,7 +10,7 @@ import torch
|
|||||||
from queue import Empty
|
from queue import Empty
|
||||||
from scipy import signal
|
from scipy import signal
|
||||||
from torch.autograd import Variable
|
from torch.autograd import Variable
|
||||||
# from Device.SunnyLinker import SunnyLinker64
|
from Device.SunnyLinker import SunnyLinker64
|
||||||
from SSMVEP.algorithm.tdca import TDCA
|
from SSMVEP.algorithm.tdca import TDCA
|
||||||
from SSMVEP.algorithm.base import generate_cca_references
|
from SSMVEP.algorithm.base import generate_cca_references
|
||||||
from concentration.algorithm.calculate_focus import Calculate
|
from concentration.algorithm.calculate_focus import Calculate
|
||||||
@@ -20,46 +19,49 @@ from Zmq.zmqServer import zmqServer
|
|||||||
from Zmq.zmqClient import zmqClient
|
from Zmq.zmqClient import zmqClient
|
||||||
from MI.Algorithm.conformer_2class import onlineTrain
|
from MI.Algorithm.conformer_2class import onlineTrain
|
||||||
from PubLibrary.InifileHelper import IniRead
|
from PubLibrary.InifileHelper import IniRead
|
||||||
from logs.log import algo_log
|
|
||||||
from SSVEP.dwfbcca import FbccaDw
|
from SSVEP.dwfbcca import FbccaDw
|
||||||
# from Tools.plot_MI_EEG import plotMain
|
from Tools.plot_MI_EEG import plotMain
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from Zmq.filterProcess import SlidingFilter
|
|
||||||
|
|
||||||
def get_root_path():
|
class Decoder_main(threading.Thread, device_type):
|
||||||
"""
|
def __init__(self, device_type=None):
|
||||||
Nuitka 打包专用:获取程序根目录(.py 或 .exe 所在目录)
|
|
||||||
"""
|
|
||||||
if getattr(sys, 'frozen', False):
|
|
||||||
# 打包后:返回 exe 所在目录
|
|
||||||
return os.path.dirname(sys.executable)
|
|
||||||
else:
|
|
||||||
# 开发时:返回 py 文件所在目录
|
|
||||||
return os.path.dirname(os.path.abspath(__file__))
|
|
||||||
MODEL_FOLDER = "online_Models"
|
|
||||||
|
|
||||||
|
|
||||||
class Decoder_main(threading.Thread):
|
|
||||||
def __init__(self, device_info=None):
|
|
||||||
threading.Thread.__init__(self)
|
threading.Thread.__init__(self)
|
||||||
self.device_info = device_info
|
|
||||||
self.Runing=True
|
self.Runing=True
|
||||||
self.decoder = None
|
self.decoder = None
|
||||||
|
|
||||||
|
self.fs = 250 # 采样率
|
||||||
|
self.energy = 0 # 电量
|
||||||
|
self.status_code = 0 # 与采集设备通信的状态码,0为异常,1为正常
|
||||||
self.decoder_class = None #解码器类别
|
self.decoder_class = None #解码器类别
|
||||||
|
|
||||||
self.decodingSteps = 0 # 0=停止解码 1=预热 2=解码中 3=解码完成,发送解码结果
|
self.decodingSteps = 0 # 0=停止解码 1=预热 2=解码中 3=解码完成,发送解码结果
|
||||||
|
self.device_info = {
|
||||||
|
'device_type': None,
|
||||||
|
'sample_rate': None,
|
||||||
|
'channel_num': None,
|
||||||
|
}
|
||||||
|
|
||||||
|
def connect(self, device_type=None, device_host=None, device_port=None, upper_host=None, upper_port=None):
|
||||||
|
self.DeviceType = device_type if device_type is not None else int(IniRead('system', 'Device_type'))
|
||||||
|
_device_host = device_host if device_host is not None else str(IniRead('system', 'Device_Host'))
|
||||||
|
_device_port = device_port if device_port is not None else int(IniRead('system', 'Device_Port'))
|
||||||
|
_upper_host = upper_host if upper_host is not None else str(IniRead('system', 'Upper_Host'))
|
||||||
|
_upper_port = upper_port if upper_port is not None else int(IniRead('system', 'Upper_Port'))
|
||||||
|
|
||||||
self.zmqServer = zmqServer(device_info=self.device_info)
|
if self.DeviceType == 1:
|
||||||
self.zmqServer.start() # 启动ZMQ接收线程
|
self.thread_data_server = SunnyLinker64(_device_host, _device_port, self.fs, 64, method='tcp')
|
||||||
|
self.thread_data_server.host = _device_host
|
||||||
|
self.thread_data_server.port = _device_port
|
||||||
|
|
||||||
self.sliding_filter = SlidingFilter(
|
self.thread_data_server.toUv = True
|
||||||
ring_buffer=self.zmqServer.filterBuffer,
|
self.thread_data_server.start()
|
||||||
n_chan=self.zmqServer.device_info['channel_nums'],
|
|
||||||
srate=self.zmqServer.device_info['sample_rate']
|
|
||||||
)
|
|
||||||
|
|
||||||
# 注册滤波结果回调(示例:打印数据形状)
|
self.zmqServer = zmqServer()
|
||||||
self.sliding_filter.filter_result_callback = self.zmqServer.send_filtered_data
|
self.zmqServer.start()
|
||||||
|
|
||||||
|
self.zmqClient = zmqClient(_upper_host, _upper_port)
|
||||||
|
self.zmqClient.set_zmq_server(self.zmqServer)
|
||||||
|
self.zmqClient.connect()
|
||||||
|
|
||||||
def is_valid_signal(self, data, threshold=1e5): # 判断当前信号是否为有效信号
|
def is_valid_signal(self, data, threshold=1e5): # 判断当前信号是否为有效信号
|
||||||
# data: (chans, samples)
|
# data: (chans, samples)
|
||||||
@@ -76,25 +78,26 @@ class Decoder_main(threading.Thread):
|
|||||||
self.decoder_class = decoder_class
|
self.decoder_class = decoder_class
|
||||||
if decoder_class == 'ssvep' or decoder_class == 'pvs':
|
if decoder_class == 'ssvep' or decoder_class == 'pvs':
|
||||||
self.n_chan = 8
|
self.n_chan = 8
|
||||||
# self.thread_data_server.interval_inited = False
|
self.thread_data_server.interval_inited = False
|
||||||
DW_cost_method, self.DW_cost_tv = ast.literal_eval(IniRead('system', 'SSVEP_ThresholdValue'))
|
DW_cost_method, self.DW_cost_tv = ast.literal_eval(IniRead('system', 'SSVEP_ThresholdValue'))
|
||||||
self.ListFreq = self.zmqServer.targetFreqs
|
self.ListFreq = self.zmqServer.targetFreqs
|
||||||
self.num_target = len(self.ListFreq)
|
self.num_target = len(self.ListFreq)
|
||||||
if self.num_target == 0:
|
if self.num_target == 0:
|
||||||
return
|
return
|
||||||
# 初始化对象 二代算法
|
# 初始化对象 二代算法
|
||||||
self.dw = FbccaDw(self.device_info['sample_rate'], self.num_target, self.n_chan, 5, 5,
|
self.dw = FbccaDw(self.fs, self.num_target, self.n_chan, 5, 5,
|
||||||
0.2, [2.0, 0.1], [8, 7], 50, DW_cost_method)
|
0.2, [2.0, 0.1], [8, 7], 50, DW_cost_method)
|
||||||
# frequence band
|
# frequence band
|
||||||
self.dw.filterFrequenceBank()
|
self.dw.filterFrequenceBank()
|
||||||
self.dw.setNotchFilterPara()
|
self.dw.setNotchFilterPara()
|
||||||
self.calculateCount = 0
|
self.calculateCount = 0
|
||||||
self.referenceData = self.dw.reference(self.ListFreq, int(50 * 0.2 * self.device_info['sample_rate']), 5)
|
self.referenceData = self.dw.reference(self.ListFreq, int(50 * 0.2 * self.fs),
|
||||||
|
5)
|
||||||
self.dw.filterInit()
|
self.dw.filterInit()
|
||||||
self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时
|
self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时
|
||||||
|
|
||||||
elif decoder_class == 'ssmvep':
|
elif decoder_class == 'ssmvep':
|
||||||
self.zmqServer.interval_init(decoder_class)
|
self.thread_data_server.interval_init(decoder_class)
|
||||||
self.n_chan = 8
|
self.n_chan = 8
|
||||||
self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
|
self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
|
||||||
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位
|
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位
|
||||||
@@ -103,12 +106,12 @@ class Decoder_main(threading.Thread):
|
|||||||
self.list_freqs = np.array([8, 9]) # 刺激频率
|
self.list_freqs = np.array([8, 9]) # 刺激频率
|
||||||
self.list_phase = np.array([0, 0]) # 相位
|
self.list_phase = np.array([0, 0]) # 相位
|
||||||
self.tdca = TDCA(padding_len=5, n_components=1)
|
self.tdca = TDCA(padding_len=5, n_components=1)
|
||||||
self.Yf = generate_cca_references(self.list_freqs, srate=self.device_info['sample_rate'], T=self.sample_length,
|
self.Yf = generate_cca_references(self.list_freqs, srate=self.fs, T=self.sample_length,
|
||||||
phases=self.list_phase, n_harmonics=5)
|
phases=self.list_phase, n_harmonics=5)
|
||||||
self.parameter_init(5,45)
|
self.parameter_init(5,45)
|
||||||
|
|
||||||
elif decoder_class == 'mi' or decoder_class == 'ma':
|
elif decoder_class == 'mi' or decoder_class == 'ma':
|
||||||
self.zmqServer.interval_init(decoder_class)
|
self.thread_data_server.interval_init(decoder_class)
|
||||||
self.n_chan = 21
|
self.n_chan = 21
|
||||||
self.interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
|
self.interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
|
||||||
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位
|
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位
|
||||||
@@ -123,7 +126,7 @@ class Decoder_main(threading.Thread):
|
|||||||
# self.win_len = 10
|
# self.win_len = 10
|
||||||
# self.win_step = 1
|
# self.win_step = 1
|
||||||
# self.low_threshold, self.high_threshold = ast.literal_eval(IniRead('system', 'concentration_ThresholdValue'))
|
# self.low_threshold, self.high_threshold = ast.literal_eval(IniRead('system', 'concentration_ThresholdValue'))
|
||||||
# self.calculate = Calculate(self.low_threshold, self.high_threshold, self.device_info['sample_rate'], self.win_len)
|
# self.calculate = Calculate(self.low_threshold, self.high_threshold, self.fs, self.win_len)
|
||||||
# self.interval_epoch = [0, 1]
|
# self.interval_epoch = [0, 1]
|
||||||
# self.parameter_init(2, 40)
|
# self.parameter_init(2, 40)
|
||||||
# # self.eegQueue moved to Calculate class
|
# # self.eegQueue moved to Calculate class
|
||||||
@@ -135,8 +138,8 @@ class Decoder_main(threading.Thread):
|
|||||||
# self.total_samples = 0 # 总采样点数
|
# self.total_samples = 0 # 总采样点数
|
||||||
# self.window_ms = 600 # 检测窗口大小 (ms)
|
# self.window_ms = 600 # 检测窗口大小 (ms)
|
||||||
# self.step_ms = 100 # 滑动步长 (ms)
|
# self.step_ms = 100 # 滑动步长 (ms)
|
||||||
# self.window_samples = int(self.window_ms * self.device_info['sample_rate'] / 1000) # 150个样本点
|
# self.window_samples = int(self.window_ms * self.fs / 1000) # 150个样本点
|
||||||
# self.step_samples = int(self.step_ms * self.device_info['sample_rate'] / 1000) # 25个样本点
|
# self.step_samples = int(self.step_ms * self.fs / 1000) # 25个样本点
|
||||||
# self.buffer_size = self.window_samples + self.step_samples * 5
|
# self.buffer_size = self.window_samples + self.step_samples * 5
|
||||||
# self.fp1_buffer = deque(maxlen=self.buffer_size)
|
# self.fp1_buffer = deque(maxlen=self.buffer_size)
|
||||||
# self.fp2_buffer = deque(maxlen=self.buffer_size)
|
# self.fp2_buffer = deque(maxlen=self.buffer_size)
|
||||||
@@ -150,11 +153,11 @@ class Decoder_main(threading.Thread):
|
|||||||
# self.double_blink_events = [] # 连续眨眼事件记录
|
# self.double_blink_events = [] # 连续眨眼事件记录
|
||||||
# self.last_double_blink_time = 0 # 上次检测到连续眨眼的时间戳
|
# self.last_double_blink_time = 0 # 上次检测到连续眨眼的时间戳
|
||||||
# self.blink_events = []
|
# self.blink_events = []
|
||||||
# self.blink_b, self.blink_a = signal.butter(4, [self.l_freq / (self.device_info['sample_rate'] / 2), self.h_freq / (self.device_info['sample_rate'] / 2)], btype='band')
|
# self.blink_b, self.blink_a = signal.butter(4, [self.l_freq / (self.fs / 2), self.h_freq / (self.fs / 2)], btype='band')
|
||||||
|
|
||||||
def parameter_init(self,bandPass_low,bandPass_high):
|
def parameter_init(self,bandPass_low,bandPass_high):
|
||||||
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in self.interval_epoch] # epoch截取信息
|
self.interval_epoch = [int(i * self.fs) for i in self.interval_epoch] # epoch截取信息
|
||||||
self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])] # 训练样本epoch
|
self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * self.fs)] # 训练样本epoch
|
||||||
self.trainData = [] #训练数据
|
self.trainData = [] #训练数据
|
||||||
self.trainLabel = [] #训练标签
|
self.trainLabel = [] #训练标签
|
||||||
self.plotData = [] #报告分析数据
|
self.plotData = [] #报告分析数据
|
||||||
@@ -162,12 +165,12 @@ class Decoder_main(threading.Thread):
|
|||||||
self.currentLabel = -1 #刺激界面当前显示的训练标签
|
self.currentLabel = -1 #刺激界面当前显示的训练标签
|
||||||
self.train_started = False #是否开始训练模型
|
self.train_started = False #是否开始训练模型
|
||||||
self.load_model = False # 调用模型是否完成的标志
|
self.load_model = False # 调用模型是否完成的标志
|
||||||
self.b_notch, self.a_notch = signal.iirnotch(50 / (self.device_info['sample_rate']/2), 30) # 50Hz工频陷波,250是采样率,30是质量因子
|
self.b_notch, self.a_notch = signal.iirnotch(50 / (self.fs/2), 30) # 50Hz工频陷波,250是采样率,30是质量因子
|
||||||
self.b_design = signal.firwin(65, [bandPass_low / (self.device_info['sample_rate']/2), bandPass_high / (self.device_info['sample_rate']/2)], pass_zero=False) # 设计8-30Hz带通滤波器
|
self.b_design = signal.firwin(65, [bandPass_low / (self.fs/2), bandPass_high / (self.fs/2)], pass_zero=False) # 设计8-30Hz带通滤波器
|
||||||
filePath = os.path.join(get_root_path(), MODEL_FOLDER) + os.sep
|
fileName = 'Model_' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
|
||||||
|
filePath = './online_Models/'
|
||||||
for old_pth in glob.glob(os.path.join(filePath, '*.pth')):
|
for old_pth in glob.glob(os.path.join(filePath, '*.pth')):
|
||||||
os.remove(old_pth)
|
os.remove(old_pth)
|
||||||
fileName = 'Model_' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
|
|
||||||
self.modelPath = ''.join([filePath, fileName, '.pth'])
|
self.modelPath = ''.join([filePath, fileName, '.pth'])
|
||||||
self.mp_data_queue = mp.Queue()
|
self.mp_data_queue = mp.Queue()
|
||||||
self.mp_result_queue = mp.Queue()
|
self.mp_result_queue = mp.Queue()
|
||||||
@@ -184,13 +187,8 @@ class Decoder_main(threading.Thread):
|
|||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
while self.Runing:
|
while self.Runing:
|
||||||
# 当滤波数据大于5秒时,启动滤波线程
|
|
||||||
if self.zmqServer.filterBuffer.GetDataLenCount() > self.device_info['sample_rate'] * 5:
|
|
||||||
algo_log("启动滤波线程", level="DEBUG")
|
|
||||||
self.sliding_filter.start()
|
|
||||||
|
|
||||||
if self.zmqServer.decoder_switch or self.zmqServer.changeTarget:
|
if self.zmqServer.decoder_switch or self.zmqServer.changeTarget:
|
||||||
algo_log(f"Decoder_class Switch Detected: {self.zmqServer.decoder_class}", level="DEBUG")
|
print(f"Decoder_class Switch Detected: {self.zmqServer.decoder_class}")
|
||||||
self.zmqServer.decoder_switch = False
|
self.zmqServer.decoder_switch = False
|
||||||
self.zmqServer.changeTarget = False
|
self.zmqServer.changeTarget = False
|
||||||
self.reset_state() # 切换前先统一清理旧状态
|
self.reset_state() # 切换前先统一清理旧状态
|
||||||
@@ -198,9 +196,57 @@ class Decoder_main(threading.Thread):
|
|||||||
|
|
||||||
# 同步信息
|
# 同步信息
|
||||||
if self.zmqServer.state_mode == 'sync':
|
if self.zmqServer.state_mode == 'sync':
|
||||||
# self.zmqClient.send_to_all('sync', self.zmqClient.state)
|
self.zmqClient.send_to_all('sync', self.zmqClient.state)
|
||||||
self.zmqServer.state_mode = 'rest'
|
self.zmqServer.state_mode = 'rest'
|
||||||
|
# 状态异常,报告上位机
|
||||||
|
if self.status_code != self.thread_data_server.status_code:
|
||||||
|
self.status_code = self.thread_data_server.status_code
|
||||||
|
self.zmqClient.send_to_all('status_code', int(self.status_code))
|
||||||
|
print('status code')
|
||||||
|
|
||||||
|
# 返回电量
|
||||||
|
if self.energy != self.thread_data_server.energy:
|
||||||
|
self.energy = self.thread_data_server.energy
|
||||||
|
self.zmqClient.send_to_all('energy', int(self.energy))
|
||||||
|
print('energy')
|
||||||
|
|
||||||
|
if self.zmqServer.open_Impedance == True: # 开启阻抗检测功能,仅运行一次
|
||||||
|
self.thread_data_server.Impedance(True)
|
||||||
|
print('Impedance')
|
||||||
|
self.zmqServer.open_Impedance = -1
|
||||||
|
elif self.zmqServer.open_Impedance == False:
|
||||||
|
self.thread_data_server.Impedance(False)
|
||||||
|
self.zmqServer.open_Impedance = -1
|
||||||
|
|
||||||
|
if self.zmqServer.get_Impedance: # 返回阻抗值
|
||||||
|
# print(self.zmqServer.get_Impedance)
|
||||||
|
# print(self.thread_data_server.GetDataLenCount())
|
||||||
|
if self.thread_data_server.GetDataLenCount() > 250:
|
||||||
|
Impe_data = self.thread_data_server.getData(250)
|
||||||
|
# 计算阻抗
|
||||||
|
imps = self.thread_data_server.getImpedance(Impe_data,self.zmqServer.decoder_class)
|
||||||
|
self.zmqClient.send_to_all('impedance', imps.tolist())
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
if self.zmqServer.getReport: #返回训练报告内容
|
||||||
|
self.zmqServer.getReport = False
|
||||||
|
allData = np.array(self.plotData)
|
||||||
|
allLabel = np.array(self.plotLabel) + 1
|
||||||
|
nTrials = min(len(allLabel),len(allData))
|
||||||
|
if nTrials < 30:
|
||||||
|
self.zmqClient.send_to_all('miReport',0)
|
||||||
|
else:
|
||||||
|
allData = allData[:nTrials]
|
||||||
|
allLabel = allLabel[:nTrials]
|
||||||
|
ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1',
|
||||||
|
'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4']
|
||||||
|
compare_names = ['C3', 'CZ', 'C4']
|
||||||
|
miReport = plotMain(ch_names=ch_names,compare_names=compare_names,Data=allData,labels=allLabel,MI_label=1,Rest_label=2,
|
||||||
|
fs=self.fs)
|
||||||
|
self.zmqClient.send_to_all('miReport',miReport)
|
||||||
|
|
||||||
|
|
||||||
|
# --- 取数优先:先执行 decoder(消费环形缓冲),再处理 plot/report 等重负载 ---
|
||||||
try:
|
try:
|
||||||
if self.decoder_class == 'ssvep' or self.decoder_class == 'pvs':
|
if self.decoder_class == 'ssvep' or self.decoder_class == 'pvs':
|
||||||
self.decoder_SSVEP()
|
self.decoder_SSVEP()
|
||||||
@@ -208,28 +254,34 @@ class Decoder_main(threading.Thread):
|
|||||||
self.decoder_SSMVEP()
|
self.decoder_SSMVEP()
|
||||||
elif self.decoder_class == 'mi':
|
elif self.decoder_class == 'mi':
|
||||||
self.decoder_MI()
|
self.decoder_MI()
|
||||||
|
elif self.decoder_class == 'concentration':
|
||||||
|
self.decoder_concentration()
|
||||||
|
elif self.decoder_class == 'blink':
|
||||||
|
self.decoder_blink()
|
||||||
else:
|
else:
|
||||||
if self.zmqServer.open_Impedance == False: # 非阻抗检测状态
|
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||||
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
|
if self.thread_data_server.GetDataLenCount() < 25:
|
||||||
time.sleep(0.005)
|
time.sleep(0.005)
|
||||||
continue;
|
continue;
|
||||||
self.zmqServer.paradigmBuffer.getData(25)
|
self.thread_data_server.getData(25)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
algo_log(f"Decoder Loop Error: {e}")
|
print(f"Decoder Loop Error: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
time.sleep(0.1) # Prevent CPU spin if error is persistent
|
time.sleep(0.1) # Prevent CPU spin if error is persistent
|
||||||
|
|
||||||
def decoder_SSVEP(self):
|
def decoder_SSVEP(self):
|
||||||
if self.zmqServer.StartDecode:
|
if self.zmqServer.StartDecode:
|
||||||
self.zmqServer.StartDecode = False
|
self.zmqServer.StartDecode = False
|
||||||
self.decodingSteps = 1
|
self.decodingSteps = 1
|
||||||
self.zmqServer.paradigmBuffer.resetAllPara()
|
self.thread_data_server.ResetAll()
|
||||||
print('启动预测')
|
print('启动预测')
|
||||||
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 50:
|
if self.thread_data_server.GetDataLenCount() < 50:
|
||||||
time.sleep(0.005)
|
time.sleep(0.005)
|
||||||
return
|
return
|
||||||
if self.zmqServer.open_Impedance: # 阻抗检测状态不解码
|
if self.zmqServer.get_Impedance != False: # 阻抗检测状态不解码
|
||||||
return
|
return
|
||||||
data = self.zmqServer.paradigmBuffer.getDataViaSSVEP(50)
|
data = self.thread_data_server.getDataViaSSVEP(50)
|
||||||
data = data[:self.n_chan, :]
|
data = data[:self.n_chan, :]
|
||||||
if self.decodingSteps == 1 and hasattr(self,'dw'): # 开始预热
|
if self.decodingSteps == 1 and hasattr(self,'dw'): # 开始预热
|
||||||
self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时
|
self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时
|
||||||
@@ -245,7 +297,7 @@ class Decoder_main(threading.Thread):
|
|||||||
print('预测结果:' + str(choosenNum) + ',计算次数:' + str(self.calculateCount))
|
print('预测结果:' + str(choosenNum) + ',计算次数:' + str(self.calculateCount))
|
||||||
self.calculateCount = 0
|
self.calculateCount = 0
|
||||||
if self.decodingSteps == 3: # 发送解码后的信息
|
if self.decodingSteps == 3: # 发送解码后的信息
|
||||||
self.zmqServer.broadcast_message('result', int(choosenNum))
|
self.zmqClient.send_to_all('result', int(choosenNum))
|
||||||
self.decodingSteps = 0
|
self.decodingSteps = 0
|
||||||
print('发送给界面完成。')
|
print('发送给界面完成。')
|
||||||
|
|
||||||
@@ -264,25 +316,25 @@ class Decoder_main(threading.Thread):
|
|||||||
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
||||||
print('模型训练完成', formatted_time)
|
print('模型训练完成', formatted_time)
|
||||||
self.load_model = True
|
self.load_model = True
|
||||||
self.zmqServer.broadcast_message('paradigm', 1)
|
self.zmqClient.send_to_all('paradigm', 1)
|
||||||
|
|
||||||
'''训练阶段采集数据'''
|
'''训练阶段采集数据'''
|
||||||
if self.zmqServer.state_mode == 'train': # 训练状态
|
if self.zmqServer.state_mode == 'train': # 训练状态
|
||||||
if self.zmqServer.StartTrain:
|
if self.zmqServer.StartTrain:
|
||||||
self.currentLabel = self.zmqServer.currentLabel
|
self.currentLabel = self.zmqServer.currentLabel
|
||||||
self.zmqServer.StartTrain = False
|
self.zmqServer.StartTrain = False
|
||||||
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \
|
||||||
self.train_epoch[1] \
|
self.train_epoch[1] \
|
||||||
+ self.zmqServer.event_inner_idx:
|
+ self.thread_data_server.event_inner_idx:
|
||||||
time.sleep(0.0001)
|
time.sleep(0.0001)
|
||||||
return
|
return
|
||||||
print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount())
|
print('训练队列数据:', self.thread_data_server.GetDataLenCount())
|
||||||
trainTrial = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 取出所有数据
|
trainTrial = self.thread_data_server.get_SSMVEPData() # 取出所有数据
|
||||||
print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, self.zmqServer.event_inner_idx])
|
print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, self.thread_data_server.event_inner_idx])
|
||||||
trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理
|
trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理
|
||||||
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.train_epoch[
|
trainTrial = trainTrial[:, self.thread_data_server.event_inner_idx + self.train_epoch[
|
||||||
0]:self.zmqServer.event_inner_idx + self.train_epoch[1]]
|
0]:self.thread_data_server.event_inner_idx + self.train_epoch[1]]
|
||||||
print('trial: ', self.zmqServer.event_inner_idx, self.train_epoch[0], self.train_epoch[1])
|
print('trial: ', self.thread_data_server.event_inner_idx, self.train_epoch[0], self.train_epoch[1])
|
||||||
if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance(
|
if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance(
|
||||||
self.trainLabel, list) \
|
self.trainLabel, list) \
|
||||||
and self.trainLabel.count(self.currentLabel) < self.single_train:
|
and self.trainLabel.count(self.currentLabel) < self.single_train:
|
||||||
@@ -300,39 +352,39 @@ class Decoder_main(threading.Thread):
|
|||||||
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
||||||
print('启动预测 ', formatted_time)
|
print('启动预测 ', formatted_time)
|
||||||
|
|
||||||
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \
|
||||||
self.interval_epoch[1] \
|
self.interval_epoch[1] \
|
||||||
+ self.zmqServer.event_inner_idx:
|
+ self.thread_data_server.event_inner_idx:
|
||||||
time.sleep(0.0001)
|
time.sleep(0.0001)
|
||||||
return
|
return
|
||||||
data = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 读取全部数据
|
data = self.thread_data_server.get_SSMVEPData() # 读取全部数据
|
||||||
print('取出的: ', data.shape, 'event: ', data[-2, self.zmqServer.event_inner_idx])
|
print('取出的: ', data.shape, 'event: ', data[-2, self.thread_data_server.event_inner_idx])
|
||||||
data = self.preprocess(data[:self.n_chan, :]) # 预处理
|
data = self.preprocess(data[:self.n_chan, :]) # 预处理
|
||||||
data = data[:,
|
data = data[:,
|
||||||
self.zmqServer.event_inner_idx + self.interval_epoch[
|
self.thread_data_server.event_inner_idx + self.interval_epoch[
|
||||||
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]
|
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]]
|
||||||
pad_eeg_test = np.zeros(
|
pad_eeg_test = np.zeros(
|
||||||
(data.shape[0], int((self.sample_length + 0.1) * self.device_info['sample_rate'])))
|
(data.shape[0], int((self.sample_length + 0.1) * self.fs)))
|
||||||
pad_eeg_test[:, :int(self.sample_length * self.device_info['sample_rate'])] = data
|
pad_eeg_test[:, :int(self.sample_length * self.fs)] = data
|
||||||
choosenNum, features_2 = self.decoder.predict(pad_eeg_test)
|
choosenNum, features_2 = self.decoder.predict(pad_eeg_test)
|
||||||
if isinstance(choosenNum, np.ndarray):
|
if isinstance(choosenNum, np.ndarray):
|
||||||
choosenNum = choosenNum[0]
|
choosenNum = choosenNum[0]
|
||||||
print('结果:', choosenNum, 'rho: ', sorted(features_2[0]),
|
print('结果:', choosenNum, 'rho: ', sorted(features_2[0]),
|
||||||
sorted(features_2[0])[-1] - sorted(features_2[0])[-2])
|
sorted(features_2[0])[-1] - sorted(features_2[0])[-2])
|
||||||
self.zmqServer.broadcast_message('result', int(choosenNum))
|
self.zmqClient.send_to_all('result', int(choosenNum))
|
||||||
print('发送给界面完成。')
|
print('发送给界面完成。')
|
||||||
else: # 休息状态
|
else: # 休息状态
|
||||||
if self.zmqServer.open_Impedance == False: # 非阻抗检测状态
|
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||||
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
|
if self.thread_data_server.GetDataLenCount() < 25:
|
||||||
time.sleep(0.005)
|
time.sleep(0.005)
|
||||||
return
|
return
|
||||||
self.zmqServer.paradigmBuffer.getData(25)
|
self.thread_data_server.getData(25)
|
||||||
|
|
||||||
def decoder_MI(self):
|
def decoder_MI(self):
|
||||||
'''模型训练'''
|
'''模型训练'''
|
||||||
if self.train_started == False and all(
|
if self.train_started == False and all(
|
||||||
self.trainLabel.count(i) >= self.single_train for i in range(self.num_target)): # 模型尚未训练
|
self.trainLabel.count(i) >= self.single_train for i in range(self.num_target)): # 模型尚未训练
|
||||||
self.zmqServer.broadcast_message('paradigm', 2) # 模型训练前,训练集采集完毕,通知上位机
|
self.zmqClient.send_to_all('paradigm', 2) # 模型训练前,训练集采集完毕,通知上位机
|
||||||
self.train_started = True
|
self.train_started = True
|
||||||
self.trainData = np.array(self.trainData)
|
self.trainData = np.array(self.trainData)
|
||||||
self.trainLabel = np.array(self.trainLabel) + 1
|
self.trainLabel = np.array(self.trainLabel) + 1
|
||||||
@@ -358,7 +410,7 @@ class Decoder_main(threading.Thread):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
_ = self.model(warmup_data)
|
_ = self.model(warmup_data)
|
||||||
self.load_model = True
|
self.load_model = True
|
||||||
self.zmqServer.broadcast_message('paradigm', 1) # 模型调用完毕,通知上位机
|
self.zmqClient.send_to_all('paradigm', 1) # 模型调用完毕,通知上位机
|
||||||
else:
|
else:
|
||||||
print("训练失败:", result['msg'])
|
print("训练失败:", result['msg'])
|
||||||
except Empty:
|
except Empty:
|
||||||
@@ -371,26 +423,26 @@ class Decoder_main(threading.Thread):
|
|||||||
if self.zmqServer.StartTrain:
|
if self.zmqServer.StartTrain:
|
||||||
self.currentLabel = self.zmqServer.currentLabel
|
self.currentLabel = self.zmqServer.currentLabel
|
||||||
self.zmqServer.StartTrain = False
|
self.zmqServer.StartTrain = False
|
||||||
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \
|
||||||
self.interval_epoch[1] \
|
self.interval_epoch[1] \
|
||||||
+ self.zmqServer.event_inner_idx:
|
+ self.thread_data_server.event_inner_idx:
|
||||||
time.sleep(0.0001)
|
time.sleep(0.0001)
|
||||||
return
|
return
|
||||||
print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount())
|
print('训练队列数据:', self.thread_data_server.GetDataLenCount())
|
||||||
originalTrial = self.zmqServer.paradigmBuffer.get_MIData() # 取出MI导联数据
|
originalTrial = self.thread_data_server.get_MIData() # 取出MI导联数据
|
||||||
print('取出的: ', originalTrial.shape, 'event: ', originalTrial[-2, self.zmqServer.event_inner_idx])
|
print('取出的: ', originalTrial.shape, 'event: ', originalTrial[-2, self.thread_data_server.event_inner_idx])
|
||||||
trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理
|
trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理
|
||||||
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.interval_epoch[
|
trainTrial = trainTrial[:, self.thread_data_server.event_inner_idx + self.interval_epoch[
|
||||||
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]
|
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]]
|
||||||
print('trial: ', self.zmqServer.event_inner_idx, self.interval_epoch[0], self.interval_epoch[1])
|
print('trial: ', self.thread_data_server.event_inner_idx, self.interval_epoch[0], self.interval_epoch[1])
|
||||||
if trainTrial.shape[1] == (self.interval_epoch[1] - self.interval_epoch[0]) and isinstance(self.trainLabel,
|
if trainTrial.shape[1] == (self.interval_epoch[1] - self.interval_epoch[0]) and isinstance(self.trainLabel,
|
||||||
list) \
|
list) \
|
||||||
and self.trainLabel.count(self.currentLabel) < self.single_train:
|
and self.trainLabel.count(self.currentLabel) < self.single_train:
|
||||||
self.trainData.append(trainTrial)
|
self.trainData.append(trainTrial)
|
||||||
self.trainLabel.append(self.currentLabel)
|
self.trainLabel.append(self.currentLabel)
|
||||||
print('训练集:', np.shape(self.trainData))
|
print('训练集:', np.shape(self.trainData))
|
||||||
self.plotData.append(originalTrial[:self.n_chan, self.zmqServer.event_inner_idx + self.interval_epoch[
|
self.plotData.append(originalTrial[:self.n_chan, self.thread_data_server.event_inner_idx + self.interval_epoch[
|
||||||
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]])
|
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]])
|
||||||
self.plotLabel.append(self.currentLabel)
|
self.plotLabel.append(self.currentLabel)
|
||||||
|
|
||||||
elif self.zmqServer.state_mode == 'predict' and self.load_model == True: # 测试状态
|
elif self.zmqServer.state_mode == 'predict' and self.load_model == True: # 测试状态
|
||||||
@@ -400,21 +452,21 @@ class Decoder_main(threading.Thread):
|
|||||||
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
||||||
print('启动预测 ', formatted_time)
|
print('启动预测 ', formatted_time)
|
||||||
|
|
||||||
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \
|
||||||
self.interval_epoch[1] \
|
self.interval_epoch[1] \
|
||||||
+ self.zmqServer.event_inner_idx:
|
+ self.thread_data_server.event_inner_idx:
|
||||||
time.sleep(0.0001)
|
time.sleep(0.0001)
|
||||||
return
|
return
|
||||||
originalData = self.zmqServer.paradigmBuffer.get_MIData() # 读取全部数据
|
originalData = self.thread_data_server.get_MIData() # 读取全部数据
|
||||||
print('取出的: ', originalData.shape, 'event: ', originalData[-2, self.zmqServer.event_inner_idx])
|
print('取出的: ', originalData.shape, 'event: ', originalData[-2, self.thread_data_server.event_inner_idx])
|
||||||
start = time.time()
|
start = time.time()
|
||||||
data = self.preprocess(originalData[:self.n_chan, :]) # 预处理
|
data = self.preprocess(originalData[:self.n_chan, :]) # 预处理
|
||||||
data = data[:,
|
data = data[:,
|
||||||
self.zmqServer.event_inner_idx + self.interval_epoch[
|
self.thread_data_server.event_inner_idx + self.interval_epoch[
|
||||||
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]
|
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]]
|
||||||
self.plotData.append(
|
self.plotData.append(
|
||||||
originalData[:self.n_chan, self.zmqServer.event_inner_idx + self.interval_epoch[
|
originalData[:self.n_chan, self.thread_data_server.event_inner_idx + self.interval_epoch[
|
||||||
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]])
|
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]])
|
||||||
|
|
||||||
test_data = data[np.newaxis, np.newaxis, :, :]
|
test_data = data[np.newaxis, np.newaxis, :, :]
|
||||||
test_data = torch.from_numpy(test_data)
|
test_data = torch.from_numpy(test_data)
|
||||||
@@ -424,40 +476,133 @@ class Decoder_main(threading.Thread):
|
|||||||
y_pred = torch.max(Cls, 1)[1]
|
y_pred = torch.max(Cls, 1)[1]
|
||||||
self.plotLabel.append(int(y_pred.item()))
|
self.plotLabel.append(int(y_pred.item()))
|
||||||
print('运动意图识别: ', y_pred)
|
print('运动意图识别: ', y_pred)
|
||||||
self.zmqServer.broadcast_message('paradigm', int(y_pred.item()))
|
self.zmqClient.send_to_all('result', int(y_pred.item()))
|
||||||
end = time.time()
|
end = time.time()
|
||||||
print(f'发送给界面完成,耗时{end - start:.3f}s。')
|
print(f'发送给界面完成,耗时{end - start:.3f}s。')
|
||||||
else: # 休息状态
|
else: # 休息状态
|
||||||
if self.zmqServer.open_Impedance == False: # 非阻抗检测状态
|
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||||
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
|
if self.thread_data_server.GetDataLenCount() < 25:
|
||||||
time.sleep(0.005)
|
time.sleep(0.005)
|
||||||
return
|
return
|
||||||
self.zmqServer.paradigmBuffer.getData(25)
|
self.thread_data_server.getData(25)
|
||||||
|
|
||||||
# def decoder_concentration(self):
|
def decoder_concentration(self):
|
||||||
# if self.zmqServer.state_mode == 'predict':
|
if self.zmqServer.state_mode == 'predict':
|
||||||
# if self.zmqServer.StartDecode:
|
if self.zmqServer.StartDecode:
|
||||||
# self.zmqServer.StartDecode = False
|
self.zmqServer.StartDecode = False
|
||||||
# self.thread_data_server.ResetAll()
|
self.thread_data_server.ResetAll()
|
||||||
# now = datetime.now()
|
now = datetime.now()
|
||||||
# formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
||||||
# print('启动专注力预测 ', formatted_time)
|
print('启动专注力预测 ', formatted_time)
|
||||||
# if self.thread_data_server.GetDataLenCount() < int(self.win_step * self.device_info['sample_rate']): # 每win_step得出一次结果
|
if self.thread_data_server.GetDataLenCount() < int(self.win_step * self.fs): # 每win_step得出一次结果
|
||||||
# time.sleep(0.005)
|
time.sleep(0.005)
|
||||||
# return
|
return
|
||||||
# if self.zmqServer.get_Impedance != False: # 阻抗检测状态不解码
|
if self.zmqServer.get_Impedance != False: # 阻抗检测状态不解码
|
||||||
# return
|
return
|
||||||
# data = self.thread_data_server.get_concentrateData(int(self.win_step * self.device_info['sample_rate'])) # 修改每次读取的数据
|
data = self.thread_data_server.get_concentrateData(int(self.win_step * self.fs)) # 修改每次读取的数据
|
||||||
# result = self.calculate.queueOpt(data)
|
result = self.calculate.queueOpt(data)
|
||||||
# if result is not None:
|
if result is not None:
|
||||||
# self.zmqClient.send_to_all('result', int(result))
|
self.zmqClient.send_to_all('result', int(result))
|
||||||
# else: # 休息状态
|
else: # 休息状态
|
||||||
# if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||||
# if self.thread_data_server.GetDataLenCount() < 25:
|
if self.thread_data_server.GetDataLenCount() < 25:
|
||||||
# time.sleep(0.005)
|
time.sleep(0.005)
|
||||||
# return
|
return
|
||||||
# self.thread_data_server.getData(25)
|
self.thread_data_server.getData(25)
|
||||||
|
|
||||||
|
#### Blink detection #####
|
||||||
|
def check_double_blink(self, current_time):
|
||||||
|
"""
|
||||||
|
检查是否检测到连续两次眨眼
|
||||||
|
@param current_time: 当前眨眼时间戳
|
||||||
|
@return: True表示检测到连续两次眨眼
|
||||||
|
"""
|
||||||
|
if len(self.blink_timestamps) < 2:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 检查是否在去抖期内
|
||||||
|
if self.last_double_blink_time > 0:
|
||||||
|
time_since_last_double_blink = current_time - self.last_double_blink_time
|
||||||
|
if time_since_last_double_blink < self.double_blink_jitter:
|
||||||
|
return False # 在去抖期内,忽略连续眨眼检测
|
||||||
|
last_time = self.blink_timestamps[-1] # 当前眨眼
|
||||||
|
prev_time = self.blink_timestamps[-2] # 上次眨眼
|
||||||
|
|
||||||
|
interval = last_time - prev_time
|
||||||
|
if interval <= self.double_blink_interval:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def process_blink_detection(self):
|
||||||
|
"""
|
||||||
|
在缓冲区数据上执行,单次眨眼检测
|
||||||
|
"""
|
||||||
|
if len(self.fp1_buffer) < self.window_samples:
|
||||||
|
return
|
||||||
|
|
||||||
|
fp1_data = np.array(list(self.fp1_buffer)[-self.window_samples:])
|
||||||
|
fp2_data = np.array(list(self.fp2_buffer)[-self.window_samples:])
|
||||||
|
# 计算FP1和FP2的平均
|
||||||
|
fp12_mean = (fp1_data + fp2_data) / 2.0
|
||||||
|
# 带通滤波
|
||||||
|
try:
|
||||||
|
fp12_filtered = signal.filtfilt(self.blink_b, self.blink_a, fp12_mean)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Filter error: {e}")
|
||||||
|
return
|
||||||
|
F = np.diff(fp12_filtered)
|
||||||
|
if len(F) < 3:
|
||||||
|
return
|
||||||
|
b, d, e = blink_detection(F, self.fs, self.Dmin, self.Dmax, self.EMin, self.EMax)
|
||||||
|
|
||||||
|
if b == 1:
|
||||||
|
samples_since_last = self.total_samples - self.last_blink_time
|
||||||
|
time_since_last_ms = (samples_since_last / self.fs) * 1000
|
||||||
|
if time_since_last_ms >= self.jitterwin: # self.jitterwin 单次眨眼去抖 using time_since_last_ms
|
||||||
|
self.blink_count += 1
|
||||||
|
self.last_blink_time = self.total_samples
|
||||||
|
current_time = time.time()
|
||||||
|
self.blink_timestamps.append(current_time)
|
||||||
|
blink_event = {
|
||||||
|
'count': self.blink_count,
|
||||||
|
'time': current_time,
|
||||||
|
'sample_index': self.total_samples,
|
||||||
|
'duration_ms': d,
|
||||||
|
'energy': e
|
||||||
|
}
|
||||||
|
self.blink_events.append(blink_event)
|
||||||
|
self.zmqClient.send_to_all('result', 1) # 检测到眨眼信号,通知上位机
|
||||||
|
if self.check_double_blink(current_time):
|
||||||
|
self.double_blink_count += 1
|
||||||
|
interval = self.blink_timestamps[-1] - self.blink_timestamps[-2]
|
||||||
|
double_blink_event = {
|
||||||
|
'double_blink_count': self.double_blink_count,
|
||||||
|
'blink1_time': self.blink_timestamps[-2],
|
||||||
|
'blink2_time': self.blink_timestamps[-1],
|
||||||
|
'interval': interval
|
||||||
|
}
|
||||||
|
self.double_blink_events.append(double_blink_event)
|
||||||
|
self.last_double_blink_time = current_time
|
||||||
|
self.zmqClient.send_to_all('result', 2) # 发送双次眨眼事件
|
||||||
|
|
||||||
|
def decoder_blink(self):
|
||||||
|
if self.thread_data_server.GetDataLenCount() < 50:
|
||||||
|
time.sleep(0.005)
|
||||||
|
return
|
||||||
|
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||||
|
data = self.thread_data_server.get_blinkData(50)
|
||||||
|
fp1_data = data[0, :] # ch1 (相当于FP1)
|
||||||
|
fp2_data = data[1, :] # ch2 (相当于FP2)
|
||||||
|
for i in range(len(fp1_data)):
|
||||||
|
self.fp1_buffer.append(fp1_data[i])
|
||||||
|
self.fp2_buffer.append(fp2_data[i])
|
||||||
|
self.total_samples += 1
|
||||||
|
self.sample_counter += 1
|
||||||
|
|
||||||
|
if self.sample_counter >= self.step_samples:
|
||||||
|
self.process_blink_detection()
|
||||||
|
self.sample_counter = 0
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
'''
|
'''
|
||||||
@@ -465,13 +610,12 @@ class Decoder_main(threading.Thread):
|
|||||||
@return:
|
@return:
|
||||||
'''
|
'''
|
||||||
self.zmqServer.stop()
|
self.zmqServer.stop()
|
||||||
self.sliding_filter.stop()
|
|
||||||
self.Runing=False
|
self.Runing=False
|
||||||
|
|
||||||
def reset_state(self):
|
def reset_state(self):
|
||||||
"""清空解码器状态和缓存数据"""
|
"""清空解码器状态和缓存数据"""
|
||||||
# 重置设备层缓存
|
# 重置设备层缓存
|
||||||
self.zmqServer.reset_state()
|
self.thread_data_server.reset_state()
|
||||||
|
|
||||||
# 重置解码状态
|
# 重置解码状态
|
||||||
self.decodingSteps = 0
|
self.decodingSteps = 0
|
||||||
|
|||||||
@@ -82,7 +82,7 @@ class MultiHeadAttention(nn.Module):
|
|||||||
values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
|
values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
|
||||||
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
|
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
fill_value = torch.finfo(torch.float64).min
|
fill_value = torch.finfo(torch.float32).min
|
||||||
energy.mask_fill(~mask, fill_value)
|
energy.mask_fill(~mask, fill_value)
|
||||||
|
|
||||||
scaling = self.emb_size ** (1 / 2)
|
scaling = self.emb_size ** (1 / 2)
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ class MultiHeadAttention(nn.Module):
|
|||||||
values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
|
values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
|
||||||
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
|
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
fill_value = torch.finfo(torch.float64).min
|
fill_value = torch.finfo(torch.float32).min
|
||||||
energy.mask_fill(~mask, fill_value)
|
energy.mask_fill(~mask, fill_value)
|
||||||
|
|
||||||
scaling = self.emb_size ** (1 / 2)
|
scaling = self.emb_size ** (1 / 2)
|
||||||
|
|||||||
@@ -13,9 +13,5 @@ Debug_64ch_Decoder_Optimize is an updated version that fixes several issues and
|
|||||||
6. decoder class切换问题
|
6. decoder class切换问题
|
||||||
7. decoder_class切换时,数据重置、各类参数重置
|
7. decoder_class切换时,数据重置、各类参数重置
|
||||||
|
|
||||||
|
# update
|
||||||
# 常用命令
|
2026年6月5日13:55:34
|
||||||
source activate 3in1Py310
|
|
||||||
python runDecoder.py
|
|
||||||
python datamock.py
|
|
||||||
python ZeroMQClient_mock.py
|
|
||||||
@@ -1,73 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
from scipy.signal import welch
|
|
||||||
from collections import deque
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Beta_Calculate():
|
|
||||||
def __init__(self, Threshold_value_low, Threshold_value_high, fs=250, win_len=5, config=None):
|
|
||||||
self.Threshold_value_low = Threshold_value_low
|
|
||||||
self.Threshold_value_high = Threshold_value_high
|
|
||||||
self.fs = fs
|
|
||||||
self.beta_result = []
|
|
||||||
self.eegQueue = deque(maxlen=win_len)
|
|
||||||
|
|
||||||
def calculate_all(self, data, fs, nperseg=1000):
|
|
||||||
mean_x = np.mean(data, axis=-1, keepdims=True)
|
|
||||||
data = data - mean_x
|
|
||||||
freqs, psd = self.compute_psd_multichannel(data, fs, nperseg)
|
|
||||||
beta_psd = np.sum(self.band_psd(freqs, psd, (13, 30)))
|
|
||||||
alpha_psd = np.sum(self.band_psd(freqs, psd, (8, 13)))
|
|
||||||
theta_psd = np.sum(self.band_psd(freqs, psd, (4, 8)))
|
|
||||||
|
|
||||||
print(f"[功率] β={beta_psd:.2f} | α={alpha_psd:.2f} | θ={theta_psd:.2f}")
|
|
||||||
|
|
||||||
return beta_psd, alpha_psd, theta_psd
|
|
||||||
|
|
||||||
def compute_psd_multichannel(self, data, fs=250, nperseg=1000):
|
|
||||||
n_samples = data.shape[-1]
|
|
||||||
if n_samples < nperseg:
|
|
||||||
nperseg = n_samples
|
|
||||||
|
|
||||||
noverlap = 500
|
|
||||||
if noverlap >= nperseg:
|
|
||||||
noverlap = int(nperseg / 2)
|
|
||||||
|
|
||||||
if nperseg == 0:
|
|
||||||
return np.array([]), np.zeros((data.shape[0], 0))
|
|
||||||
|
|
||||||
freqs, psd = welch(data, fs=fs, nperseg=nperseg, noverlap=noverlap, axis=-1)
|
|
||||||
return freqs, psd
|
|
||||||
|
|
||||||
def band_psd(self, freqs, psd, band):
|
|
||||||
idx = np.logical_and(freqs >= band[0], freqs <= band[1])
|
|
||||||
return np.sum(psd[:, idx], axis=-1)
|
|
||||||
|
|
||||||
|
|
||||||
def reset_queue(self):
|
|
||||||
self.eegQueue.clear()
|
|
||||||
|
|
||||||
|
|
||||||
def queueOpt(self, data):
|
|
||||||
if data is None or data.size == 0:
|
|
||||||
return None
|
|
||||||
if len(self.eegQueue) < self.eegQueue.maxlen:
|
|
||||||
self.eegQueue.append(data)
|
|
||||||
else:
|
|
||||||
self.eegQueue.append(data)
|
|
||||||
|
|
||||||
if len(self.eegQueue) == self.eegQueue.maxlen:
|
|
||||||
eegData = np.hstack([self.eegQueue[i] for i in range(len(self.eegQueue))])
|
|
||||||
if eegData.size == 0:
|
|
||||||
return None
|
|
||||||
eegData -= np.mean(eegData, axis=-1, keepdims=True)
|
|
||||||
|
|
||||||
beta_psd, alpha_psd, theta_psd = self.calculate_all(eegData, fs=self.fs, nperseg=1000)
|
|
||||||
|
|
||||||
return (beta_psd)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -1,166 +0,0 @@
|
|||||||
import zmq
|
|
||||||
import time
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import threading
|
|
||||||
|
|
||||||
def receive_messages(socket, stop_event):
|
|
||||||
"""
|
|
||||||
后台线程函数,用于持续接收服务器消息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
socket (zmq.Socket): ZeroMQ套接字
|
|
||||||
stop_event (threading.Event): 停止事件,用于通知线程退出
|
|
||||||
"""
|
|
||||||
print("开始持续接收服务器数据...")
|
|
||||||
print("-" * 50)
|
|
||||||
|
|
||||||
while not stop_event.is_set():
|
|
||||||
try:
|
|
||||||
# 设置接收超时为1秒,避免阻塞
|
|
||||||
socket.setsockopt(zmq.RCVTIMEO, 1000)
|
|
||||||
# 接收服务器的消息
|
|
||||||
frames = socket.recv_multipart()
|
|
||||||
|
|
||||||
# DEALER 套接字接收消息格式:[身份标识, 空帧, 消息内容]
|
|
||||||
# 使用frames[-1]获取最后一帧,无论中间有多少空帧
|
|
||||||
if len(frames) >= 2:
|
|
||||||
message = frames[-1].decode('utf-8')
|
|
||||||
|
|
||||||
# 尝试解析为JSON格式
|
|
||||||
try:
|
|
||||||
json_message = json.loads(message)
|
|
||||||
# 检查消息长度
|
|
||||||
json_str = str(json_message)
|
|
||||||
if len(json_str) > 100:
|
|
||||||
print(f"收到服务器数据 (JSON): {json_str[:100]}...")
|
|
||||||
else:
|
|
||||||
print(f"收到服务器数据 (JSON): {json_message}")
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
# 检查消息长度
|
|
||||||
if len(message) > 100:
|
|
||||||
print(f"收到服务器数据 (原始): {message[:100]}...")
|
|
||||||
else:
|
|
||||||
print(f"收到服务器数据 (原始): {message}")
|
|
||||||
else:
|
|
||||||
print(f"收到服务器数据 (格式异常): {frames}")
|
|
||||||
|
|
||||||
except zmq.Again:
|
|
||||||
# 接收超时,继续循环
|
|
||||||
continue
|
|
||||||
except Exception as e:
|
|
||||||
print(f"接收消息时发生错误: {e}")
|
|
||||||
# 短暂暂停后继续接收
|
|
||||||
time.sleep(1)
|
|
||||||
|
|
||||||
print("接收线程已停止。")
|
|
||||||
|
|
||||||
def zero_mq_client(server_address="tcp://127.0.0.1:8099"):
|
|
||||||
"""
|
|
||||||
ZeroMQ客户端函数,用于与服务器通信
|
|
||||||
|
|
||||||
Args:
|
|
||||||
server_address (str): 服务器地址,格式为"tcp://IP:端口"
|
|
||||||
"""
|
|
||||||
# 创建 ZeroMQ 上下文
|
|
||||||
context = zmq.Context()
|
|
||||||
|
|
||||||
# 创建 DEALER 套接字
|
|
||||||
socket = context.socket(zmq.DEALER)
|
|
||||||
|
|
||||||
# 生成唯一的身份标识
|
|
||||||
identity = str('wdd').encode('utf-8')
|
|
||||||
socket.setsockopt(zmq.IDENTITY, identity)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 连接到服务器
|
|
||||||
print(f"连接到服务器 {server_address}...")
|
|
||||||
socket.connect(server_address)
|
|
||||||
|
|
||||||
# 定义消息集
|
|
||||||
message_set = [
|
|
||||||
{"method": "sync", "params": 1},
|
|
||||||
{"method": "decoderClass", "params": "mi"},
|
|
||||||
{"method": "decoderClass", "params": "ssvep"},
|
|
||||||
{"method": "decoderClass", "params": "ssmvep"},
|
|
||||||
{"method": "decoderClass", "params": "blink"},
|
|
||||||
{"method": "decoderClass", "params": "concentration"},
|
|
||||||
{"method": "train", "params": 0},
|
|
||||||
{"method": "train", "params": 1},
|
|
||||||
{"method": "rest", "params": 0},
|
|
||||||
{"method": "predict", "params": 1},
|
|
||||||
{"method": "getReport", "params": 0}
|
|
||||||
]
|
|
||||||
|
|
||||||
# 打印消息集
|
|
||||||
print("消息集:")
|
|
||||||
for i, msg in enumerate(message_set):
|
|
||||||
print(f"[{i}] {msg}")
|
|
||||||
print("-" * 50)
|
|
||||||
|
|
||||||
# 创建停止事件
|
|
||||||
stop_event = threading.Event()
|
|
||||||
|
|
||||||
# 启动接收线程
|
|
||||||
receive_thread = threading.Thread(target=receive_messages, args=(socket, stop_event))
|
|
||||||
receive_thread.daemon = True # 设置为守护线程,主线程退出时自动退出
|
|
||||||
receive_thread.start()
|
|
||||||
|
|
||||||
# 主线程处理控制台输入
|
|
||||||
print("输入消息序号发送对应消息,输入'q'退出程序:")
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
# 获取用户输入
|
|
||||||
user_input = input("请输入消息序号: ")
|
|
||||||
|
|
||||||
# 检查是否退出
|
|
||||||
if user_input.lower() == 'q':
|
|
||||||
print("正在退出程序...")
|
|
||||||
break
|
|
||||||
|
|
||||||
# 尝试转换为整数
|
|
||||||
msg_index = int(user_input)
|
|
||||||
|
|
||||||
# 检查序号是否有效
|
|
||||||
if 0 <= msg_index < len(message_set):
|
|
||||||
# 获取对应的消息
|
|
||||||
selected_message = message_set[msg_index]
|
|
||||||
|
|
||||||
# 将消息转换为 JSON 字符串
|
|
||||||
json_message = json.dumps(selected_message)
|
|
||||||
|
|
||||||
# 打印发送信息
|
|
||||||
print(f"\n发送消息 (大小: {len(json_message)} 字节)...")
|
|
||||||
print(f"消息方法: {selected_message['method']}")
|
|
||||||
print(f"参数值: {selected_message['params']}")
|
|
||||||
|
|
||||||
# DEALER 套接字发送消息,包含身份标识和空帧
|
|
||||||
socket.send_multipart([identity, json_message.encode('utf-8')])
|
|
||||||
print("消息发送完成!")
|
|
||||||
print("-" * 50)
|
|
||||||
else:
|
|
||||||
print(f"无效的消息序号,请输入 0-{len(message_set)-1} 之间的数字。")
|
|
||||||
print("消息集:")
|
|
||||||
for i, msg in enumerate(message_set):
|
|
||||||
print(f"[{i}] {msg}")
|
|
||||||
print("-" * 50)
|
|
||||||
|
|
||||||
except ValueError:
|
|
||||||
print("请输入有效的数字或'q'退出。")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"处理输入时发生错误: {e}")
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print("\n程序被手动终止。")
|
|
||||||
finally:
|
|
||||||
# 停止接收线程
|
|
||||||
stop_event.set()
|
|
||||||
# 等待接收线程停止
|
|
||||||
time.sleep(1)
|
|
||||||
# 关闭套接字和上下文
|
|
||||||
socket.close()
|
|
||||||
context.term()
|
|
||||||
print("客户端已关闭。")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
zero_mq_client()
|
|
||||||
@@ -5,13 +5,12 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy import signal
|
from scipy import signal
|
||||||
import threading
|
import threading
|
||||||
from logs.log import algo_log
|
|
||||||
|
|
||||||
class ParadigmRingBuffer:
|
class ParadigmRingBuffer:
|
||||||
def __init__(self, n_chan, n_points):
|
def __init__(self, n_chan, n_points):
|
||||||
self.n_chan = n_chan
|
self.n_chan = n_chan
|
||||||
self.n_points = n_points
|
self.n_points = n_points
|
||||||
self.buffer = np.zeros((n_chan, n_points), dtype=np.float64)
|
self.buffer = np.zeros((n_chan, n_points))
|
||||||
self.currentPtr = 0
|
self.currentPtr = 0
|
||||||
self.readPtr = 0
|
self.readPtr = 0
|
||||||
self.nUpdate = 0
|
self.nUpdate = 0
|
||||||
@@ -20,8 +19,7 @@ class ParadigmRingBuffer:
|
|||||||
## append buffer and update current pointer
|
## append buffer and update current pointer
|
||||||
def appendBuffer(self, data):
|
def appendBuffer(self, data):
|
||||||
if self.nUpdate == self.n_points:
|
if self.nUpdate == self.n_points:
|
||||||
# raise Exception("Buffer is full")
|
raise Exception("Buffer is full")
|
||||||
algo_log("Buffer is full", record_once=True)
|
|
||||||
|
|
||||||
n = data.shape[1]
|
n = data.shape[1]
|
||||||
|
|
||||||
@@ -67,56 +65,13 @@ class ParadigmRingBuffer:
|
|||||||
'''
|
'''
|
||||||
return self.nUpdate
|
return self.nUpdate
|
||||||
|
|
||||||
# ========== 各范式数据访问接口 ==========
|
|
||||||
def get_MIData(self):
|
|
||||||
"""获取MI导联数据 (21通道 + 事件)"""
|
|
||||||
data = self.getData(self.GetDataLenCount())
|
|
||||||
rows_to_extract = [8, 15, 12, 14, 18, 23, 16, 59, 50, 58, 17, 45, 29, 11, 10, 19, 20, 61, 51, 60, 21, 64, 65]
|
|
||||||
row_to_select = np.array(rows_to_extract)
|
|
||||||
if data.shape[1] > 0:
|
|
||||||
return data[row_to_select, :]
|
|
||||||
return np.zeros((len(rows_to_extract), 0))
|
|
||||||
|
|
||||||
def get_SSMVEPData(self):
|
|
||||||
"""获取SSMVEP导联数据 (8通道 + 事件)"""
|
|
||||||
data = self.getData(self.GetDataLenCount())
|
|
||||||
rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64, 65]
|
|
||||||
row_to_select = np.array(rows_to_extract)
|
|
||||||
if data.shape[1] > 0:
|
|
||||||
return data[row_to_select, :]
|
|
||||||
return np.zeros((len(rows_to_extract), 0))
|
|
||||||
|
|
||||||
def getDataViaSSVEP(self, count):
|
|
||||||
"""获取SSVEP数据 (8通道 + 事件)"""
|
|
||||||
data = self.getData(count)
|
|
||||||
rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64]
|
|
||||||
row_to_select = np.array(rows_to_extract)
|
|
||||||
if data.shape[1] > 0:
|
|
||||||
return data[row_to_select, :]
|
|
||||||
return np.zeros((len(rows_to_extract), 0))
|
|
||||||
|
|
||||||
def get_concentrateData(self, count):
|
|
||||||
"""获取专注力数据 (2通道)"""
|
|
||||||
data = self.getData(count)
|
|
||||||
rows_to_extract = [0, 1]
|
|
||||||
row_to_select = np.array(rows_to_extract)
|
|
||||||
if data.shape[1] > 0:
|
|
||||||
return data[row_to_select, :]
|
|
||||||
return np.zeros((len(rows_to_extract), 0))
|
|
||||||
|
|
||||||
def get_blinkData(self, count):
|
|
||||||
"""获取眨眼数据 (2通道)"""
|
|
||||||
data = self.getData(count)
|
|
||||||
rows_to_extract = [0, 1]
|
|
||||||
row_to_select = np.array(rows_to_extract)
|
|
||||||
if data.shape[1] > 0:
|
|
||||||
return data[row_to_select, :]
|
|
||||||
return np.zeros((len(rows_to_extract), 0))
|
|
||||||
|
|
||||||
# reset buffer
|
# reset buffer
|
||||||
def resetAllPara(self):
|
def resetAllPara(self):
|
||||||
self.nUpdate = 0
|
self.nUpdate = 0
|
||||||
self.currentPtr = 0
|
self.currentPtr = 0
|
||||||
self.readPtr = 0
|
self.readPtr = 0 # add by lizhenhua 清空读指针
|
||||||
self.buffer.fill(0.0)
|
self.buffer = np.zeros((self.n_chan, self.n_points)) # add by lizhenhua 清空环形缓冲区
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,15 +3,13 @@
|
|||||||
数据滤波模块
|
数据滤波模块
|
||||||
"""
|
"""
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import time
|
|
||||||
import threading
|
import threading
|
||||||
from scipy import signal
|
|
||||||
from logs.log import algo_log
|
from logs.log import algo_log
|
||||||
|
|
||||||
class FilterRingBuffer:
|
class FilterRingBuffer:
|
||||||
def __init__(self, n_chan, n_points):
|
def __init__(self, n_chan, n_points):
|
||||||
"""
|
"""
|
||||||
初始化纯数据环形缓存(线程安全)
|
初始化纯数据环形缓存
|
||||||
:param n_chan: 通道数
|
:param n_chan: 通道数
|
||||||
:param n_points: 总缓存点数(与paradigmRingBuffer参数完全一致)
|
:param n_points: 总缓存点数(与paradigmRingBuffer参数完全一致)
|
||||||
"""
|
"""
|
||||||
@@ -19,9 +17,11 @@ class FilterRingBuffer:
|
|||||||
self.n_points = n_points
|
self.n_points = n_points
|
||||||
|
|
||||||
self.buffer = np.zeros((n_chan, n_points), dtype=np.float64)
|
self.buffer = np.zeros((n_chan, n_points), dtype=np.float64)
|
||||||
self.current_ptr = 0 # 写入指针:指向下一个要写入的位置
|
self.current_ptr = 0 # 写入指针
|
||||||
self.total_samples = 0 # 已写入总点数
|
self.total_samples = 0 # 已写入总点数
|
||||||
self.lock = threading.Lock() # 线程安全锁
|
|
||||||
|
# 线程安全锁(多线程环境必须)
|
||||||
|
self.lock = threading.Lock()
|
||||||
|
|
||||||
def appendBuffer(self, data):
|
def appendBuffer(self, data):
|
||||||
"""
|
"""
|
||||||
@@ -33,7 +33,7 @@ class FilterRingBuffer:
|
|||||||
if n == 0:
|
if n == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
# 环形写入逻辑:指针到末尾则绕回
|
# 环形写入逻辑
|
||||||
write_end = self.current_ptr + n
|
write_end = self.current_ptr + n
|
||||||
if write_end <= self.n_points:
|
if write_end <= self.n_points:
|
||||||
self.buffer[:, self.current_ptr:write_end] = data
|
self.buffer[:, self.current_ptr:write_end] = data
|
||||||
@@ -42,14 +42,13 @@ class FilterRingBuffer:
|
|||||||
self.buffer[:, self.current_ptr:] = data[:, :split]
|
self.buffer[:, self.current_ptr:] = data[:, :split]
|
||||||
self.buffer[:, :write_end - self.n_points] = data[:, split:]
|
self.buffer[:, :write_end - self.n_points] = data[:, split:]
|
||||||
|
|
||||||
# 更新指针(取模保证环形)和计数(不超过缓存总长度)
|
# 更新指针和计数
|
||||||
self.current_ptr = write_end % self.n_points
|
self.current_ptr = write_end % self.n_points
|
||||||
self.total_samples = min(self.total_samples + n, self.n_points)
|
self.total_samples = min(self.total_samples + n, self.n_points)
|
||||||
|
|
||||||
def getData(self, count):
|
def getData(self, count):
|
||||||
"""
|
"""
|
||||||
从最新位置向前读取count个点(环形读取)
|
从读指针位置读取count个点(与paradigmRingBuffer接口一致)
|
||||||
核心逻辑:current_ptr是下一个写入位置 → 最新数据在current_ptr之前
|
|
||||||
:param count: 读取点数
|
:param count: 读取点数
|
||||||
:return: np.ndarray, shape=(n_chan, count)
|
:return: np.ndarray, shape=(n_chan, count)
|
||||||
"""
|
"""
|
||||||
@@ -58,14 +57,13 @@ class FilterRingBuffer:
|
|||||||
if count == 0:
|
if count == 0:
|
||||||
return np.zeros((self.n_chan, 0))
|
return np.zeros((self.n_chan, 0))
|
||||||
|
|
||||||
# 环形读取:end是当前写入指针(最新数据的下一位),start是end - count
|
# 环形读取逻辑(与paradigmRingBuffer完全相同)
|
||||||
end = self.current_ptr
|
end = self.current_ptr
|
||||||
start = end - count
|
start = end - count
|
||||||
if start >= 0:
|
if start >= 0:
|
||||||
return self.buffer[:, start:end].copy()
|
return self.buffer[:, start:end].copy()
|
||||||
else:
|
else:
|
||||||
# 跨环形边界:前半部分从缓存末尾取,后半部分从开头取
|
part1 = self.buffer[:, start:]
|
||||||
part1 = self.buffer[:, start:] # start为负,等价于n_points + start
|
|
||||||
part2 = self.buffer[:, :end]
|
part2 = self.buffer[:, :end]
|
||||||
return np.concatenate((part1, part2), axis=1)
|
return np.concatenate((part1, part2), axis=1)
|
||||||
|
|
||||||
@@ -73,7 +71,7 @@ class FilterRingBuffer:
|
|||||||
"""
|
"""
|
||||||
扩展方法:获取最新的n个点(不移动读指针,用于滑动窗口)
|
扩展方法:获取最新的n个点(不移动读指针,用于滑动窗口)
|
||||||
:param n: 点数
|
:param n: 点数
|
||||||
:return: np.ndarray, shape=(n_chan, n) | None(数据不足时)
|
:return: np.ndarray, shape=(n_chan, n)
|
||||||
"""
|
"""
|
||||||
with self.lock:
|
with self.lock:
|
||||||
if self.total_samples < n:
|
if self.total_samples < n:
|
||||||
@@ -94,35 +92,43 @@ class FilterRingBuffer:
|
|||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# 2. 独立滑动滤波类(仅负责滤波业务逻辑,不关心缓存实现)
|
# 2. 独立滑动滤波类(仅负责滤波业务逻辑,不关心缓存实现)
|
||||||
|
# 可替换任意缓存实现,只要实现appendBuffer、get_latest_n_points接口
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
class SlidingFilter(threading.Thread):
|
class SlidingFilter:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
ring_buffer: FilterRingBuffer,
|
|
||||||
n_chan=66,
|
n_chan=66,
|
||||||
srate=250,
|
srate=250,
|
||||||
|
buffer_sec=5,
|
||||||
window_sec=3,
|
window_sec=3,
|
||||||
step_sec=0.2
|
step_sec=0.2,
|
||||||
|
packet_size=5
|
||||||
):
|
):
|
||||||
super().__init__(daemon=True)
|
"""
|
||||||
|
初始化滑动滤波器
|
||||||
|
:param n_chan: 通道数
|
||||||
|
:param srate: 采样率
|
||||||
|
:param buffer_sec: 总缓存时长(秒)
|
||||||
|
:param window_sec: 滤波窗口时长(秒)
|
||||||
|
:param step_sec: 滑动步长/输出时长(秒)
|
||||||
|
:param packet_size: 每包数据点数(20ms一包=5点)
|
||||||
|
"""
|
||||||
# 核心参数
|
# 核心参数
|
||||||
self.n_chan = n_chan
|
self.n_chan = n_chan
|
||||||
self.srate = srate
|
self.srate = srate
|
||||||
self.step_sec = step_sec # 200ms滑动步长
|
self.buffer_size = int(srate * buffer_sec)
|
||||||
self.window_sec = window_sec # 3秒窗口
|
self.window_size = int(srate * window_sec)
|
||||||
self.step_sec = step_sec # 200ms滑动步长
|
self.step_size = int(srate * step_sec)
|
||||||
self.window_size = int(srate * window_sec) # 3秒点数:250*3=750
|
self.packet_size = packet_size
|
||||||
self.step_size = int(srate * step_sec) # 200ms点数:250*0.2=50
|
|
||||||
|
|
||||||
# 关联ZMQServer的环形缓存(解耦:仅依赖接口)
|
# 初始化纯数据缓存(解耦核心)
|
||||||
self.ring_buffer = ring_buffer
|
self.buffer = FilterRingBuffer(n_chan, self.buffer_size)
|
||||||
# 线程控制
|
|
||||||
self.running = threading.Event()
|
|
||||||
self.running.set()
|
|
||||||
# 滤波结果回调(外部可注册,获取滤波后的数据)
|
|
||||||
self.filter_result_callback = None
|
|
||||||
|
|
||||||
# 预计算滤波器系数(仅执行一次)
|
# 滤波触发计数器
|
||||||
|
self.packet_count = 0
|
||||||
|
self.ready_to_filter = False
|
||||||
|
|
||||||
|
# 预计算滤波器系数
|
||||||
self._init_filters()
|
self._init_filters()
|
||||||
|
|
||||||
def _init_filters(self):
|
def _init_filters(self):
|
||||||
@@ -138,71 +144,65 @@ class SlidingFilter(threading.Thread):
|
|||||||
)
|
)
|
||||||
self.a_bp = np.array([1.0])
|
self.a_bp = np.array([1.0])
|
||||||
|
|
||||||
def _filter_window_data(self, window_data):
|
def append_and_check_trigger(self, raw_data):
|
||||||
"""对3秒窗口数据执行滤波,返回无边界效应的200ms数据"""
|
"""
|
||||||
|
追加单包原始数据并检查是否触发滤波
|
||||||
|
:param raw_data: 上位机原始数据,shape=(packet_size, n_chan)
|
||||||
|
:return: bool: 是否触发本次滤波
|
||||||
|
"""
|
||||||
|
# 转置为标准格式:(通道数, 点数)
|
||||||
|
data = raw_data.T.astype(np.float64)
|
||||||
|
|
||||||
|
# 写入缓存(纯缓存操作)
|
||||||
|
self.buffer.appendBuffer(data)
|
||||||
|
|
||||||
|
# 更新包计数器
|
||||||
|
self.packet_count += 1
|
||||||
|
|
||||||
|
# 检查滤波触发条件:数据≥窗口长度 且 累计满一个步长的包数
|
||||||
|
packets_per_step = int(self.step_size / self.packet_size) # 10包=200ms
|
||||||
|
if (self.buffer.GetDataLenCount() >= self.window_size
|
||||||
|
and self.packet_count >= packets_per_step):
|
||||||
|
self.packet_count = 0
|
||||||
|
self.ready_to_filter = True
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def filter_and_get_output(self):
|
||||||
|
"""
|
||||||
|
执行滤波并返回无边界效应的输出数据
|
||||||
|
:return: np.ndarray: 滤波后数据,shape=(n_chan, step_size)
|
||||||
|
"""
|
||||||
|
if not self.ready_to_filter:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 获取最新的完整滤波窗口数据
|
||||||
|
window_data = self.buffer.get_latest_n_points(self.window_size)
|
||||||
|
if window_data is None:
|
||||||
|
self.ready_to_filter = False
|
||||||
|
return None
|
||||||
|
|
||||||
# 零相位滤波(无延迟,无边界效应)
|
# 零相位滤波(无延迟,无边界效应)
|
||||||
filtered = window_data - np.mean(window_data, axis=-1, keepdims=True)
|
filtered = window_data - np.mean(window_data, axis=-1, keepdims=True)
|
||||||
filtered = signal.filtfilt(self.b_notch, self.a_notch, filtered, axis=-1)
|
filtered = signal.filtfilt(self.b_notch, self.a_notch, filtered, axis=-1)
|
||||||
filtered = signal.filtfilt(self.b_bp, self.a_bp, filtered, axis=-1)
|
filtered = signal.filtfilt(self.b_bp, self.a_bp, filtered, axis=-1)
|
||||||
|
|
||||||
# 提取倒数第二个200ms的数据(完全避开两端边界效应)
|
# 提取倒数第二个步长的数据(完全避开两端边界效应)
|
||||||
# 窗口长度750,步长50 → start=750-100=650,end=750-50=700
|
|
||||||
start_idx = self.window_size - 2 * self.step_size
|
start_idx = self.window_size - 2 * self.step_size
|
||||||
end_idx = self.window_size - self.step_size
|
end_idx = self.window_size - self.step_size
|
||||||
output_data = filtered[:, start_idx:end_idx].copy()
|
output_data = filtered[:, start_idx:end_idx].copy()
|
||||||
|
|
||||||
|
# 重置触发标志
|
||||||
|
self.ready_to_filter = False
|
||||||
|
|
||||||
return output_data
|
return output_data
|
||||||
|
|
||||||
def run(self):
|
def reset(self):
|
||||||
"""线程主逻辑:精确200ms触发一次滤波"""
|
"""重置滤波器和缓存"""
|
||||||
# 精确定时核心:基于perf_counter计算下一次执行时间,补偿sleep误差
|
self.buffer.resetAllPara()
|
||||||
interval = self.step_sec # 200ms = 0.2秒
|
self.packet_count = 0
|
||||||
next_run_time = time.perf_counter()
|
self.ready_to_filter = False
|
||||||
|
|
||||||
while self.running.is_set():
|
def get_buffer_length(self):
|
||||||
# 1. 等待到下一次执行时间(精确定时)
|
"""获取当前缓存数据长度"""
|
||||||
current_time = time.perf_counter()
|
return self.buffer.GetDataLenCount()
|
||||||
if current_time < next_run_time:
|
|
||||||
time.sleep(next_run_time - current_time)
|
|
||||||
next_run_time += interval # 补偿:下次执行时间基于上一次目标时间
|
|
||||||
else:
|
|
||||||
# 若超时(如滤波耗时超过200ms),重置下一次时间(避免累积误差)
|
|
||||||
algo_log("滤波耗时超过200ms,定时偏移", level='debug')
|
|
||||||
next_run_time = time.perf_counter() + interval
|
|
||||||
|
|
||||||
# 2. 执行滤波逻辑
|
|
||||||
try:
|
|
||||||
# 获取最新的3秒窗口数据
|
|
||||||
window_data = self.ring_buffer.get_latest_n_points(self.window_size)
|
|
||||||
if window_data is None:
|
|
||||||
algo_log(f"缓存数据不足,当前缓存{self.ring_buffer.GetDataLenCount()}点,需{self.window_size}点", level='debug')
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 滤波并提取无边界效应的200ms数据
|
|
||||||
filtered_data = self._filter_window_data(window_data)
|
|
||||||
|
|
||||||
# 回调返回结果(外部可处理)
|
|
||||||
if self.filter_result_callback is not None:
|
|
||||||
self.filter_result_callback(filtered_data[:64, :]) # 只发送前64通道数据
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
algo_log(f"滤波执行异常: {e}", level='error')
|
|
||||||
|
|
||||||
def set_result_callback(self, callback):
|
|
||||||
"""注册滤波结果回调函数"""
|
|
||||||
self.filter_result_callback = callback
|
|
||||||
|
|
||||||
def stop(self):
|
|
||||||
"""停止滤波线程(安全版)"""
|
|
||||||
# 1. 先设置停止标志(Event.clear()是线程安全的)
|
|
||||||
self.running.clear()
|
|
||||||
|
|
||||||
# 2. 核心修复:只有线程已启动且正在运行时才调用join
|
|
||||||
if self.is_alive():
|
|
||||||
# 等待线程正常退出,最多1秒
|
|
||||||
self.join(timeout=1)
|
|
||||||
# 超时未退出时打印警告,便于排查问题
|
|
||||||
if self.is_alive():
|
|
||||||
algo_log("警告:滤波线程在1秒内未正常退出,可能存在阻塞操作", level="WARNING")
|
|
||||||
|
|
||||||
# 3. 无论线程是否启动,都打印停止日志
|
|
||||||
algo_log("滤波线程已停止")
|
|
||||||
420
Zmq/zmqServer.py
420
Zmq/zmqServer.py
@@ -1,391 +1,241 @@
|
|||||||
# -*-coding:utf-8 -*-
|
|
||||||
import ast
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import zmq
|
||||||
import threading
|
import threading
|
||||||
import json
|
import json
|
||||||
import queue
|
import queue
|
||||||
from typing import Dict
|
# from Device.SunnyLinker import SunnyLinker64
|
||||||
import datetime
|
from dataBuffer import ParadigmRingBuffer
|
||||||
import time
|
from filterProcess import FilterRingBuffer
|
||||||
|
|
||||||
from Zmq.dataBuffer import ParadigmRingBuffer
|
|
||||||
from Zmq.filterProcess import FilterRingBuffer
|
|
||||||
from PubLibrary.InifileHelper import IniRead
|
|
||||||
from logs.log import algo_log
|
from logs.log import algo_log
|
||||||
|
|
||||||
import zmq
|
|
||||||
|
|
||||||
class zmqServer(threading.Thread):
|
class zmqServer(threading.Thread):
|
||||||
def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100, device_info=None):
|
def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100, device_info=None):
|
||||||
threading.Thread.__init__(self)
|
threading.Thread.__init__(self)
|
||||||
self.device_info = device_info
|
|
||||||
|
|
||||||
self.host = host
|
self.host = host
|
||||||
self.cmd_port = cmd_port # 命令交互端口:收JSON命令 + 返JSON结果
|
self.cmd_port = cmd_port # 命令交互端口
|
||||||
self.data_port = data_port # 数据交互端口:收二进制原始脑电 + 返二进制滤波结果
|
self.data_port = data_port # 数据接收端口
|
||||||
self.running = False
|
self.running = False
|
||||||
|
|
||||||
# 原有业务状态变量
|
# 原有业务状态变量
|
||||||
self.open_Impedance = False #当前系统处于阻抗检测状态
|
# self.get_Impedance = False # 是否返回阻抗值
|
||||||
self.StartDecode = False
|
# self.open_Impedance = None # 是否开启阻抗检测功能
|
||||||
self.StartTrain = False
|
self.StartDecode = False # false 停止解码,true=开始解码
|
||||||
self.state_mode = None
|
self.StartTrain = False # False未进入训练状态,True处于训练状态
|
||||||
self.currentLabel = -1
|
self.state_mode = None # 'train'为训练状态,’rest'为休息状态,'test'为测试状态
|
||||||
self.IsExitApp = False
|
self.currentLabel = -1 # 接收刺激端消息,了解刺激端当前的训练标签
|
||||||
|
self.IsExitApp = False # 当socket收到2的时候,就置为True,代表要退出系统了。
|
||||||
|
# self.getReport = False # 获取训练报告内容
|
||||||
self.daemon = True
|
self.daemon = True
|
||||||
|
|
||||||
# 双环形缓冲区
|
# 范式数据缓存
|
||||||
self.paradigmBuffer = ParadigmRingBuffer(
|
self.paradigmBuffer = ParadigmRingBuffer(66, 2500)
|
||||||
self.device_info['channel_nums'],
|
self.filterBuffer = FilterRingBuffer(66, 2500)
|
||||||
self.device_info['sample_rate'] * 10
|
|
||||||
)
|
|
||||||
self.filterBuffer = FilterRingBuffer(
|
|
||||||
self.device_info['channel_nums'],
|
|
||||||
self.device_info['sample_rate'] * 10
|
|
||||||
)
|
|
||||||
self.paradigmBufferLock = threading.Lock()
|
|
||||||
self.filterBufferLock = threading.Lock()
|
|
||||||
|
|
||||||
# ZMQ上下文与套接字
|
|
||||||
|
# 命令与数据通信
|
||||||
self.context = zmq.Context()
|
self.context = zmq.Context()
|
||||||
|
# 指令通道 (8099) - ROUTER:短JSON命令,低频率
|
||||||
# 8099命令端口:ROUTER
|
|
||||||
self.cmd_socket = self.context.socket(zmq.ROUTER)
|
self.cmd_socket = self.context.socket(zmq.ROUTER)
|
||||||
self.cmd_socket.setsockopt(zmq.SocketOption.RCVHWM, 100)
|
self.cmd_socket.setsockopt(zmq.RCVHWM, 100) # 指令不需要大缓存,100条足够
|
||||||
self.cmd_socket.setsockopt(zmq.SocketOption.SNDHWM, 100)
|
self.cmd_socket.setsockopt(zmq.SNDHWM, 100)
|
||||||
|
self.cmd_socket.setsockopt(zmq.TCP_NODELAY, 1) # 禁用Nagle算法,降低指令延迟
|
||||||
self.cmd_socket.bind(f"tcp://{self.host}:{cmd_port}")
|
self.cmd_socket.bind(f"tcp://{self.host}:{cmd_port}")
|
||||||
|
|
||||||
# 8100数据端口:ROUTER
|
# 数据通道 (8100) - ROUTER:高频脑电二进制流
|
||||||
self.data_socket = self.context.socket(zmq.ROUTER)
|
self.data_socket = self.context.socket(zmq.ROUTER)
|
||||||
self.data_socket.setsockopt(zmq.SocketOption.RCVHWM, 500)
|
self.data_socket.setsockopt(zmq.RCVHWM, 500) # 500包=10秒缓存,足够应对短时卡顿
|
||||||
self.data_socket.setsockopt(zmq.SocketOption.SNDHWM, 100) # 添加发送高水位线
|
self.data_socket.setsockopt(zmq.TCP_NODELAY, 1) # 禁用Nagle算法,减少数据传输延迟
|
||||||
self.data_socket.bind(f"tcp://{self.host}:{data_port}")
|
self.data_socket.bind(f"tcp://{self.host}:{data_port}")
|
||||||
|
|
||||||
# Poller轮询器
|
# Poller 轮训器(保持不变)
|
||||||
self.poller = zmq.Poller()
|
self.poller = zmq.Poller()
|
||||||
self.poller.register(self.cmd_socket, zmq.POLLIN)
|
self.poller.register(self.cmd_socket, zmq.POLLIN)
|
||||||
self.poller.register(self.data_socket, zmq.POLLIN)
|
self.poller.register(self.data_socket, zmq.POLLIN)
|
||||||
|
|
||||||
# 业务变量
|
# 业务变量
|
||||||
self.targetFreqs = []
|
self.targetFreqs = []
|
||||||
self.changeTarget = False
|
self.changeTarget = False # 更换目标频率
|
||||||
|
# self.sunnyLinker = SunnyLinker64(None, None, None, None,None) #单例模式类,已在Decoder实例化
|
||||||
self.labels = [0x01, 0x02,0x03]
|
self.labels = [0x01, 0x02,0x03]
|
||||||
self.decoder_switch = False
|
self.decoder_switch = False #更换解码器
|
||||||
self.decoder_class = None
|
self.decoder_class = None #解码器类别 'ssvep','ssmvep','mi'
|
||||||
|
|
||||||
# 客户端管理(单客户端场景)
|
# 客户端管理 - 区分命令/数据客户端
|
||||||
self.cmd_clients = set()
|
self.cmd_clients = set() # 命令端口客户端ID
|
||||||
self.data_clients = set()
|
self.data_clients = set() # 数据端口客户端ID
|
||||||
self.current_data_client = None # 唯一数据客户端身份,用于发送滤波结果
|
self.send_queue = queue.Queue() # 发送队列(仅用于命令端口广播)
|
||||||
|
|
||||||
# 发送队列(双端口分离)
|
|
||||||
self.cmd_send_queue = queue.Queue() # 8099端口命令结果队列
|
|
||||||
self.data_send_queue = queue.Queue() # 8100端口滤波数据队列
|
|
||||||
|
|
||||||
# 范式buffer与事件检测参数
|
|
||||||
self.predict_event = 99
|
|
||||||
self.events = [1, 2, self.predict_event]
|
|
||||||
self.latency = 50
|
|
||||||
self.train_latency = 50
|
|
||||||
self.count_events = {}
|
|
||||||
self.epoch_finished = False
|
|
||||||
self.pack_contain_event = False
|
|
||||||
self.event_inner_idx = -1
|
|
||||||
self.interval_inited = False
|
|
||||||
|
|
||||||
def reset_state(self):
|
|
||||||
"""清空采集器状态和缓存数据"""
|
|
||||||
with self.paradigmBufferLock:
|
|
||||||
self.paradigmBuffer.resetAllPara()
|
|
||||||
self.count_events = {}
|
|
||||||
self.epoch_finished = False
|
|
||||||
self.pack_contain_event = False
|
|
||||||
self.event_inner_idx = -1
|
|
||||||
self.interval_inited = False
|
|
||||||
|
|
||||||
def interval_init(self, decoder_class):
|
|
||||||
if decoder_class == 'ssmvep':
|
|
||||||
interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
|
|
||||||
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch]
|
|
||||||
self.train_epoch = [
|
|
||||||
int(self.interval_epoch[0]),
|
|
||||||
int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])
|
|
||||||
]
|
|
||||||
self.latency = (self.interval_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5
|
|
||||||
self.train_latency = (self.train_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5
|
|
||||||
|
|
||||||
elif decoder_class == 'mi':
|
|
||||||
interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
|
|
||||||
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch]
|
|
||||||
self.train_epoch = self.interval_epoch.copy()
|
|
||||||
self.latency = self.interval_epoch[1] // 5
|
|
||||||
self.train_latency = self.latency
|
|
||||||
|
|
||||||
algo_log(f"时间窗初始化完成: {interval_epoch}", level="INFO")
|
|
||||||
self.count_events: Dict[str, int] = {}
|
|
||||||
self.event_inner_idx = -1
|
|
||||||
self.epoch_finished = False
|
|
||||||
self.pack_contain_event = False
|
|
||||||
self.predict_event = 99
|
|
||||||
self.events = [1, 2, self.predict_event]
|
|
||||||
self.interval_inited = True
|
|
||||||
|
|
||||||
# -------------------------- 8099端口:命令结果广播 --------------------------
|
|
||||||
def broadcast_message(self, method, params):
|
def broadcast_message(self, method, params):
|
||||||
"""
|
"""Put message into queue to be sent to all command clients"""
|
||||||
向所有8099端口客户端广播JSON格式的命令结果
|
self.send_queue.put((method, params))
|
||||||
用于:解码结果、训练状态、错误提示、进度通知等
|
|
||||||
"""
|
|
||||||
self.cmd_send_queue.put((method, params))
|
|
||||||
|
|
||||||
def _process_cmd_send_queue(self):
|
|
||||||
"""处理8099端口发送队列,在主线程执行(保证ZMQ线程安全)"""
|
|
||||||
while not self.cmd_send_queue.empty():
|
|
||||||
method, params = self.cmd_send_queue.get()
|
|
||||||
if not self.cmd_clients:
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
msg = {'method': method, 'params': params}
|
|
||||||
msg_bytes = json.dumps(msg).encode('utf-8')
|
|
||||||
|
|
||||||
algo_log(f"发送命令结果: {msg}", level="DEBUG")
|
|
||||||
|
|
||||||
# 广播到所有命令客户端
|
|
||||||
for client_id in list(self.cmd_clients):
|
|
||||||
try:
|
|
||||||
self.cmd_socket.send_multipart([client_id, b"", msg_bytes])
|
|
||||||
except Exception as e:
|
|
||||||
algo_log(f"向命令客户端{client_id}发送失败: {e}", level="ERROR")
|
|
||||||
self.cmd_clients.discard(client_id)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
algo_log(f"命令结果打包失败: {e}", level="ERROR")
|
|
||||||
|
|
||||||
# -------------------------- 8100端口:滤波结果发送 --------------------------
|
|
||||||
def send_filtered_data(self, filtered_data):
|
|
||||||
"""
|
|
||||||
向8100端口客户端发送二进制格式的滤波结果
|
|
||||||
用于:上位机实时绘图的脑电波形数据
|
|
||||||
:param filtered_data: 滤波后数据,shape=(通道数, 50),float64格式
|
|
||||||
"""
|
|
||||||
if self.current_data_client is None:
|
|
||||||
algo_log("数据客户端未连接,跳过滤波数据发送", level="WARNING")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 转置为上位机需要的[50, 通道数]格式
|
|
||||||
filtered_data = filtered_data.T.astype(np.float64)
|
|
||||||
send_buf = filtered_data.tobytes()
|
|
||||||
algo_log(f"发送滤波数据,长度: {len(send_buf)}字节, filtered_data.shape: {filtered_data.shape}", level="DEBUG")
|
|
||||||
self.data_send_queue.put(send_buf)
|
|
||||||
|
|
||||||
def _process_data_send_queue(self):
|
|
||||||
"""处理8100端口发送队列,在主线程执行(保证ZMQ线程安全)"""
|
|
||||||
while not self.data_send_queue.empty():
|
|
||||||
send_buf = self.data_send_queue.get()
|
|
||||||
if self.current_data_client is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 标准ROUTER发送格式:[客户端ID, 空分隔帧, 数据帧]
|
|
||||||
self.data_socket.send_multipart([
|
|
||||||
self.current_data_client,
|
|
||||||
b"",
|
|
||||||
send_buf
|
|
||||||
])
|
|
||||||
algo_log(f"发送滤波数据成功,长度: {len(send_buf)}字节", level="DEBUG")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
algo_log(f"发送滤波数据失败: {e}", level="ERROR")
|
|
||||||
# 客户端断开,重置身份
|
|
||||||
self.current_data_client = None
|
|
||||||
self.data_clients.clear()
|
|
||||||
|
|
||||||
# -------------------------- 命令端口消息处理 --------------------------
|
|
||||||
def _handle_cmd_message(self, frames):
|
def _handle_cmd_message(self, frames):
|
||||||
"""处理8099端口JSON命令消息"""
|
"""处理命令端口消息(原有命令交互逻辑)"""
|
||||||
if len(frames) < 3:
|
if len(frames) < 3:
|
||||||
algo_log(f"无效命令帧:长度不足3帧,实际{len(frames)}", level="ERROR")
|
|
||||||
return
|
return
|
||||||
|
|
||||||
ident, _, message_bytes = frames[:3]
|
ident, _, message_bytes = frames[:3]
|
||||||
|
|
||||||
# 注册新的命令客户端
|
# 注册新的命令客户端
|
||||||
if ident not in self.cmd_clients:
|
if ident not in self.cmd_clients:
|
||||||
self.cmd_clients.add(ident)
|
self.cmd_clients.add(ident)
|
||||||
algo_log(f"新命令客户端连接成功: {ident}", level="INFO")
|
print(f"New CMD Client Connected: {ident} (port: {self.cmd_port})")
|
||||||
|
|
||||||
# 解析JSON命令
|
# 解析消息
|
||||||
try:
|
try:
|
||||||
message = json.loads(message_bytes.decode('utf-8'))
|
message = json.loads(message_bytes.decode('utf-8'))
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
algo_log(f"无效JSON命令: {message_bytes.hex()}", level="ERROR")
|
print(f"Invalid JSON from CMD client {ident}")
|
||||||
self.broadcast_message("error", {"code": 400, "message": "无效JSON格式"})
|
continue
|
||||||
return
|
print(f"Received CMD request: {message}")
|
||||||
|
|
||||||
algo_log(f"收到命令: {message}", level="INFO")
|
|
||||||
method = message.get("method")
|
method = message.get("method")
|
||||||
params = message.get("params")
|
params = message.get("params")
|
||||||
|
|
||||||
# 命令处理逻辑
|
# 原有命令处理逻辑
|
||||||
if method == "sync":
|
if method == "sync":
|
||||||
self.state_mode = 'sync'
|
self.state_mode = 'sync'
|
||||||
elif method == "targetFreqs":
|
if method == "targetFreqs":
|
||||||
if not isinstance(params, list):
|
if not isinstance(params, list):
|
||||||
algo_log(f"targetFreqs must be a list")
|
print('targetFreqs must be a list')
|
||||||
return
|
continue
|
||||||
if params != self.targetFreqs:
|
if params != self.targetFreqs:
|
||||||
self.targetFreqs = params
|
self.targetFreqs = params
|
||||||
self.changeTarget = True
|
self.changeTarget = True
|
||||||
elif method == "decoderClass":
|
if method == "decoderClass":
|
||||||
if not isinstance(params, str):
|
if not isinstance(params, str):
|
||||||
algo_log(f"decoderClass必须是字符串")
|
print('decoderClass must be a str')
|
||||||
return
|
continue
|
||||||
if params != self.decoder_class:
|
if params != self.decoder_class:
|
||||||
self.decoder_class = params
|
self.decoder_class = params
|
||||||
self.decoder_switch = True
|
self.decoder_switch = True
|
||||||
elif method == "train":
|
if method == "getReport":
|
||||||
|
self.getReport = True
|
||||||
|
if method == "train":#训练状态
|
||||||
self.state_mode = 'train'
|
self.state_mode = 'train'
|
||||||
self.StartTrain = True
|
self.StartTrain = True
|
||||||
self.currentLabel = params
|
self.currentLabel = params # 当前刺激端的训练标签
|
||||||
elif method == "predict":
|
self.sunnyLinker.push_trigger(self.labels[self.currentLabel])
|
||||||
|
elif method == "predict":#预测状态
|
||||||
self.state_mode = 'predict'
|
self.state_mode = 'predict'
|
||||||
if params == 1: #开始解码
|
if params == 1: #开始解码
|
||||||
self.StartDecode = True
|
self.StartDecode = True
|
||||||
|
self.sunnyLinker.push_trigger(0x63)
|
||||||
elif params == 2: #停止解码
|
elif params == 2: #停止解码
|
||||||
self.IsExitApp = True
|
self.IsExitApp = True
|
||||||
self.running = False
|
self.running = False
|
||||||
elif method == "rest":
|
elif method == "rest": #休息状态
|
||||||
self.state_mode = 'rest'
|
self.state_mode = 'rest'
|
||||||
elif method == "impedance":
|
# elif method == "impedance":
|
||||||
if params == 1:
|
# if params == 1:
|
||||||
self.open_Impedance = True
|
# self.open_Impedance = True # 开启阻抗
|
||||||
elif params == 2:
|
# self.get_Impedance = True # 返回阻抗
|
||||||
self.open_Impedance = False
|
# elif params == 2:
|
||||||
else:
|
# self.open_Impedance = False # 关闭阻抗
|
||||||
self.broadcast_message("error", {"code": 404, "message": f"未知命令: {method}"})
|
# self.get_Impedance = False # 停止返回阻抗
|
||||||
|
|
||||||
# -------------------------- 数据端口消息处理 --------------------------
|
|
||||||
def _handle_data_message(self, frames):
|
def _handle_data_message(self, frames):
|
||||||
"""处理8100端口二进制脑电数据消息"""
|
"""
|
||||||
algo_log(f"收到数据帧,总帧数:{len(frames)}", level="DEBUG", record_once=False)
|
处理8100端口原始脑电二进制数据
|
||||||
# 然后再进行解析
|
固定格式:上位机发送 (5,66) float32 二维数组字节流(已转换为微伏物理量)→ 转置为 (66,5) 写入双缓冲区
|
||||||
if len(frames) == 4:
|
"""
|
||||||
# 你的上位机格式
|
# 1. 校验ZMQ消息帧完整性
|
||||||
ident, sender_ident, empty_sep, data_bytes = frames[:4]
|
if len(frames) < 3:
|
||||||
elif len(frames) == 3:
|
print(f"[ERROR] 无效数据帧:长度不足3帧,实际长度={len(frames)}")
|
||||||
# 标准格式
|
|
||||||
ident, empty_sep, data_bytes = frames[:3]
|
|
||||||
else:
|
|
||||||
return
|
|
||||||
# 注册新的数据客户端(单客户端场景,自动覆盖旧身份)
|
|
||||||
if ident not in self.data_clients:
|
|
||||||
self.data_clients.clear() # 单客户端,只保留最新连接
|
|
||||||
self.data_clients.add(ident)
|
|
||||||
self.current_data_client = ident
|
|
||||||
algo_log(f"新数据客户端连接成功: {ident}", level="INFO")
|
|
||||||
try:
|
|
||||||
# 精确长度校验
|
|
||||||
EXPECTED_BYTES = self.device_info['frame_points'] * self.device_info['channel_nums'] * 4
|
|
||||||
if len(data_bytes) != EXPECTED_BYTES:
|
|
||||||
algo_log(f"数据长度错误:期望{EXPECTED_BYTES}字节,实际{len(data_bytes)}字节", level="ERROR")
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# 零拷贝解析 + 维度转换
|
ident, _, data_bytes = frames[:3]
|
||||||
data_np = np.frombuffer(data_bytes, dtype=np.float64)
|
|
||||||
data_np = data_np.reshape(self.device_info['frame_points'], self.device_info['channel_nums'])
|
# 2. 客户端管理(单客户端场景,自动更新最新身份)
|
||||||
|
if ident not in self.data_clients:
|
||||||
|
self.data_clients.add(ident)
|
||||||
|
self.current_data_client = ident # 保存唯一客户端身份,用于后续回复滤波结果
|
||||||
|
print(f"[INFO] 新数据客户端连接成功:{ident}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 3. 精确长度校验(核心:固定(5,66) float32 = 5*66*4=1320字节,与int32字节数相同)
|
||||||
|
EXPECTED_BYTES = 5 * 66 * 4 # 每个float32占4字节
|
||||||
|
if len(data_bytes) != EXPECTED_BYTES:
|
||||||
|
print(f"[ERROR] 数据长度错误:期望{EXPECTED_BYTES}字节,实际{len(data_bytes)}字节")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 4. 零拷贝二进制解析 + 维度转换
|
||||||
|
# 步骤:字节流 → (330,) float32数组 → (5,66) 原始格式 → 转置为 (66,5) 缓冲区标准格式
|
||||||
|
data_np = np.frombuffer(data_bytes, dtype=np.float32)
|
||||||
|
# 重塑为上位机原始维度
|
||||||
|
data_np = data_np.reshape(5, 66)
|
||||||
|
# 转置为(通道数, 采样点数)标准格式,转换为float64保证滤波运算精度
|
||||||
data_np = data_np.T.astype(np.float64)
|
data_np = data_np.T.astype(np.float64)
|
||||||
|
|
||||||
# 写入滤波缓冲区
|
# 5. 同时写入双环形缓冲区(方法名与现有类保持一致:appendBuffer)
|
||||||
with self.filterBufferLock:
|
# 注意:上位机已发送微伏物理量,无需再乘以增益系数
|
||||||
|
self.paradigmBuffer.appendBuffer(data_np)
|
||||||
self.filterBuffer.appendBuffer(data_np)
|
self.filterBuffer.appendBuffer(data_np)
|
||||||
|
|
||||||
# 写入范式缓冲区
|
# 生产环境必须注释!每秒50次打印会导致CPU占用飙升30%以上
|
||||||
with self.paradigmBufferLock:
|
algo_log(f"数据写入成功:shape={data_np.shape}, 范围=[{data_np.min():.2f}, {data_np.max():.2f}] μV", level="DEBUG", record_once=True)
|
||||||
if self.interval_inited:
|
|
||||||
self.epoch_finished = self.detect_event(data_np)
|
|
||||||
if self.pack_contain_event:
|
|
||||||
self.paradigmBuffer.resetAllPara()
|
|
||||||
self.paradigmBuffer.appendBuffer(data_np)
|
|
||||||
if self.epoch_finished:
|
|
||||||
algo_log('Epoch采集完成: ' + datetime.datetime.now().strftime('%H:%M:%S.%f')[:-3], level="DEBUG")
|
|
||||||
else:
|
|
||||||
self.paradigmBuffer.appendBuffer(data_np)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
algo_log(f"数据处理失败: {str(e)}", level="ERROR")
|
algo_log(f"数据处理失败:{str(e)}", level="ERROR")
|
||||||
if IniRead('system', 'algo_log_level', 'INFO') == 'DEBUG':
|
# 调试阶段临时打开,生产环境务必注释
|
||||||
import traceback
|
import traceback
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
# -------------------------- 事件检测 --------------------------
|
def _process_send_queue(self):
|
||||||
def detect_event(self, samples):
|
"""处理发送队列,向所有命令客户端广播消息"""
|
||||||
self.pack_contain_event = False
|
while not self.send_queue.empty():
|
||||||
# 第65通道为事件通道
|
method, params = self.send_queue.get()
|
||||||
events = np.array(samples[-2])[0].tolist()
|
if self.cmd_clients:
|
||||||
for idx, event in enumerate(events):
|
try:
|
||||||
if int(event) in self.events:
|
msg = {'method': method, 'params': params}
|
||||||
new_key = "".join(
|
msg_bytes = json.dumps(msg).encode('utf-8')
|
||||||
[
|
|
||||||
str(event),
|
# 打印日志(隐藏大尺寸数据)
|
||||||
datetime.datetime.now().strftime("%Y-%m-%d \
|
if method in ['single_trial_plot', 'miReport']:
|
||||||
-%H-%M-%S"),
|
print(f"{{'method': '{method}', 'params': <Base64 Image Data>}}")
|
||||||
]
|
|
||||||
)
|
|
||||||
if event == self.predict_event:
|
|
||||||
self.count_events[new_key] = self.latency + 1
|
|
||||||
else:
|
else:
|
||||||
self.count_events[new_key] = self.train_latency + 1
|
print(f"Sending CMD message: {msg}")
|
||||||
self.event_inner_idx = idx
|
|
||||||
self.pack_contain_event = True
|
|
||||||
|
|
||||||
# 倒计时并清理过期事件
|
# 广播到所有命令客户端
|
||||||
drop_items = []
|
for client_id in list(self.cmd_clients):
|
||||||
for key, value in self.count_events.items():
|
try:
|
||||||
value -= 1
|
self.cmd_socket.send_multipart([client_id, b'', msg_bytes])
|
||||||
if value == 0:
|
except Exception as e:
|
||||||
drop_items.append(key)
|
print(f"Error sending to CMD client {client_id}: {e}")
|
||||||
self.count_events[key] = value
|
self.cmd_clients.discard(client_id) # 移除失效客户端
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error preparing broadcast: {e}")
|
||||||
|
|
||||||
for key in drop_items:
|
|
||||||
del self.count_events[key]
|
|
||||||
|
|
||||||
if drop_items:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
# -------------------------- 主循环 --------------------------
|
|
||||||
def run(self):
|
def run(self):
|
||||||
self.running = True
|
self.running = True
|
||||||
algo_log(f"ZMQ服务器启动成功 - 命令端口: {self.cmd_port}, 数据端口: {self.data_port}", level="INFO")
|
print(f"ZMQ Server started - CMD Port: {self.cmd_port}, DATA Port: {self.data_port}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while self.running:
|
while self.running:
|
||||||
# 1. 处理两个端口的发送队列(必须在主线程执行)
|
# 1. 处理发送队列(命令端口广播)
|
||||||
self._process_cmd_send_queue()
|
self._process_send_queue()
|
||||||
self._process_data_send_queue()
|
|
||||||
|
|
||||||
# 2. 轮询监听两个端口的输入事件
|
# 2. 轮训监听两个Socket的输入事件(10ms超时,避免阻塞)
|
||||||
socks = dict(self.poller.poll(50))
|
socks = dict(self.poller.poll(10))
|
||||||
|
|
||||||
# 处理8099命令端口消息
|
# 处理命令端口消息
|
||||||
if self.cmd_socket in socks and socks[self.cmd_socket] == zmq.POLLIN:
|
if self.cmd_socket in socks and socks[self.cmd_socket] == zmq.POLLIN:
|
||||||
frames = self.cmd_socket.recv_multipart()
|
frames = self.cmd_socket.recv_multipart()
|
||||||
self._handle_cmd_message(frames)
|
self._handle_cmd_message(frames)
|
||||||
|
|
||||||
# 处理8100数据端口消息
|
# 处理数据端口消息
|
||||||
if self.data_socket in socks and socks[self.data_socket] == zmq.POLLIN:
|
if self.data_socket in socks and socks[self.data_socket] == zmq.POLLIN:
|
||||||
frames = self.data_socket.recv_multipart()
|
frames = self.data_socket.recv_multipart()
|
||||||
self._handle_data_message(frames)
|
self._handle_data_message(frames)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
algo_log(f"服务器主循环异常: {e}", level="ERROR")
|
print(f"Server error occurred: {e}")
|
||||||
finally:
|
finally:
|
||||||
self.running = False
|
self.running = False
|
||||||
# 优雅关闭所有资源
|
# 关闭所有Socket和上下文
|
||||||
self.cmd_socket.close()
|
self.cmd_socket.close()
|
||||||
self.data_socket.close()
|
self.data_socket.close()
|
||||||
self.context.term()
|
self.context.term()
|
||||||
algo_log("ZMQ服务器已关闭", level="INFO")
|
print("Server sockets and context closed.")
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
"""显式关闭服务器"""
|
"""显式关闭服务器"""
|
||||||
@@ -393,10 +243,10 @@ class zmqServer(threading.Thread):
|
|||||||
self.cmd_socket.close()
|
self.cmd_socket.close()
|
||||||
self.data_socket.close()
|
self.data_socket.close()
|
||||||
self.context.term()
|
self.context.term()
|
||||||
algo_log(f"服务器已显式关闭 - 命令端口: {self.cmd_port}, 数据端口: {self.data_port}", level="INFO")
|
print(f"Server closed explicitly - CMD Port: {self.cmd_port}, DATA Port: {self.data_port}")
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# 初始化并启动服务器
|
# 初始化并启动服务器(默认cmd=8099, data=8100)
|
||||||
server = zmqServer()
|
server = zmqServer()
|
||||||
server.start()
|
server.start()
|
||||||
|
|
||||||
@@ -405,5 +255,5 @@ if __name__ == '__main__':
|
|||||||
while server.running:
|
while server.running:
|
||||||
threading.Event().wait(1)
|
threading.Event().wait(1)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
algo_log("收到键盘中断信号,正在停止服务器...", level="INFO")
|
print("Received KeyboardInterrupt, stopping server...")
|
||||||
server.stop()
|
server.stop()
|
||||||
445
Zmq/zmqServer1.py
Normal file
445
Zmq/zmqServer1.py
Normal file
@@ -0,0 +1,445 @@
|
|||||||
|
import numpy as np
|
||||||
|
import zmq
|
||||||
|
import threading
|
||||||
|
import json
|
||||||
|
import queue
|
||||||
|
import time
|
||||||
|
from Device.SunnyLinker import SunnyLinker64, RingBuffer
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
|
||||||
|
class zmqServer(threading.Thread):
|
||||||
|
def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100):
|
||||||
|
threading.Thread.__init__(self)
|
||||||
|
self.host = host
|
||||||
|
self.cmd_port = cmd_port
|
||||||
|
self.data_port = data_port
|
||||||
|
self.running = False
|
||||||
|
self.get_Impedance = False
|
||||||
|
self.open_Impedance = None
|
||||||
|
self.StartDecode = False
|
||||||
|
self.StartTrain = False
|
||||||
|
self.state_mode = None
|
||||||
|
self.currentLabel = -1
|
||||||
|
self.IsExitApp = False
|
||||||
|
self.getReport = False
|
||||||
|
self.daemon = True
|
||||||
|
|
||||||
|
# ZMQ Context
|
||||||
|
self.context = zmq.Context()
|
||||||
|
|
||||||
|
# 指令通道 (8099) - ROUTER
|
||||||
|
self.cmd_socket = self.context.socket(zmq.ROUTER)
|
||||||
|
self.cmd_socket.setsockopt(zmq.RCVHWM, 1000)
|
||||||
|
self.cmd_socket.setsockopt(zmq.SNDHWM, 1000)
|
||||||
|
self.cmd_socket.bind(f"tcp://{self.host}:{cmd_port}")
|
||||||
|
|
||||||
|
# 数据通道 (8100)) - ROUTER
|
||||||
|
self.data_socket = self.context.socket(zmq.ROUTER)
|
||||||
|
self.data_socket.setsockopt(zmq.RCVHWM, 1000)
|
||||||
|
self.data_socket.setsockopt(zmq.RCVTIMEO, 50)
|
||||||
|
self.data_socket.bind(f"tcp://{self.host}:{data_port}")
|
||||||
|
|
||||||
|
self.targetFreqs = []
|
||||||
|
self.changeTarget = False
|
||||||
|
self.sunnyLinker = SunnyLinker64(None, None, None, None, None)
|
||||||
|
self.labels = [0x01, 0x02, 0x03]
|
||||||
|
|
||||||
|
self.decoder_switch = False
|
||||||
|
self.decoder_class = None
|
||||||
|
self.cmd_clients = set()
|
||||||
|
self.data_clients = set()
|
||||||
|
self.send_queue = queue.Queue()
|
||||||
|
|
||||||
|
# ========== 数据缓冲区 (RingBuffer) ==========
|
||||||
|
# 与 SunnyLinker 保持一致,使用 RingBuffer
|
||||||
|
# 66 = 64 EEG通道 + 1 事件通道(第65) + 1 标签序号通道(第66)
|
||||||
|
# 缓存约 10 秒数据 (250Hz * 10s = 2500 点)
|
||||||
|
self.n_chan = 66
|
||||||
|
self.t_buffer = 10.0 # 缓冲区时长(秒)
|
||||||
|
self.__ringBuffer = RingBuffer(self.n_chan, int(self.t_buffer * 250))
|
||||||
|
|
||||||
|
# 事件检测相关
|
||||||
|
self._event_lock = threading.Lock()
|
||||||
|
self._epoch_finished = False
|
||||||
|
self._event_inner_idx = -1
|
||||||
|
self.pack_contain_event = False
|
||||||
|
self.predict_event = 99
|
||||||
|
self.events = [1, 2, self.predict_event]
|
||||||
|
self.count_events = {}
|
||||||
|
self.latency = 50
|
||||||
|
self.train_latency = 50
|
||||||
|
|
||||||
|
# 当前事件标签序号 (从第66通道获取)
|
||||||
|
self.current_label_index = 0
|
||||||
|
|
||||||
|
# 初始化标志
|
||||||
|
self._interval_inited = False
|
||||||
|
self._currentLabel = -1
|
||||||
|
|
||||||
|
# 注册的客户端(兼容旧接口)
|
||||||
|
self.clients = set()
|
||||||
|
|
||||||
|
# ========== 事件属性:线程安全访问 ==========
|
||||||
|
@property
|
||||||
|
def epoch_finished(self):
|
||||||
|
with self._event_lock:
|
||||||
|
return self._epoch_finished
|
||||||
|
|
||||||
|
@epoch_finished.setter
|
||||||
|
def epoch_finished(self, value):
|
||||||
|
with self._event_lock:
|
||||||
|
self._epoch_finished = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def event_inner_idx(self):
|
||||||
|
with self._event_lock:
|
||||||
|
return self._event_inner_idx
|
||||||
|
|
||||||
|
@event_inner_idx.setter
|
||||||
|
def event_inner_idx(self, value):
|
||||||
|
with self._event_lock:
|
||||||
|
self._event_inner_idx = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def interval_inited(self):
|
||||||
|
return self._interval_inited
|
||||||
|
|
||||||
|
@interval_inited.setter
|
||||||
|
def interval_inited(self, value):
|
||||||
|
self._interval_inited = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def currentLabel(self):
|
||||||
|
return self._currentLabel
|
||||||
|
|
||||||
|
@currentLabel.setter
|
||||||
|
def currentLabel(self, value):
|
||||||
|
self._currentLabel = value
|
||||||
|
|
||||||
|
def broadcast_message(self, method, params):
|
||||||
|
"""Put message into queue to be sent to all connected clients"""
|
||||||
|
self.send_queue.put((method, params))
|
||||||
|
|
||||||
|
# ========== 数据缓冲区操作接口 ==========
|
||||||
|
def GetDataLenCount(self):
|
||||||
|
"""返回缓冲区当前数据点数"""
|
||||||
|
return self.__ringBuffer.nUpdate
|
||||||
|
|
||||||
|
def getData(self, count):
|
||||||
|
"""获取最新count个数据点,不消费(只读)"""
|
||||||
|
with self.__ringBuffer.RingBufferLock:
|
||||||
|
count = min(count, self.__ringBuffer.nUpdate)
|
||||||
|
if count == 0:
|
||||||
|
return np.zeros((self.n_chan, 0))
|
||||||
|
|
||||||
|
# 计算读取范围(从尾部取最新数据)
|
||||||
|
read_end = (self.__ringBuffer.currentPtr - 1) % self.__ringBuffer.n_points
|
||||||
|
read_start = (read_end - count + 1) % self.__ringBuffer.n_points
|
||||||
|
|
||||||
|
if self.__ringBuffer.currentPtr == 0:
|
||||||
|
read_start = self.__ringBuffer.n_points - count
|
||||||
|
read_end = self.__ringBuffer.n_points - 1
|
||||||
|
|
||||||
|
if read_start <= read_end:
|
||||||
|
data = self.__ringBuffer.buffer[:, read_start:read_end + 1]
|
||||||
|
else:
|
||||||
|
part1 = self.__ringBuffer.buffer[:, read_start:]
|
||||||
|
part2 = self.__ringBuffer.buffer[:, :read_end + 1]
|
||||||
|
data = np.concatenate((part1, part2), axis=1)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def consumeData(self, count):
|
||||||
|
"""消费(丢弃)指定数量的数据点,从头部移除"""
|
||||||
|
with self.__ringBuffer.RingBufferLock:
|
||||||
|
count = min(count, self.__ringBuffer.nUpdate)
|
||||||
|
self.__ringBuffer.readPtr = (self.__ringBuffer.readPtr + count) % self.__ringBuffer.n_points
|
||||||
|
self.__ringBuffer.nUpdate -= count
|
||||||
|
|
||||||
|
def ResetAll(self):
|
||||||
|
"""重置缓冲区"""
|
||||||
|
with self.__ringBuffer.RingBufferLock:
|
||||||
|
self.__ringBuffer.resetAllPara()
|
||||||
|
with self._event_lock:
|
||||||
|
self._epoch_finished = False
|
||||||
|
self._event_inner_idx = -1
|
||||||
|
self.pack_contain_event = False
|
||||||
|
self.count_events.clear()
|
||||||
|
self.current_label_index = 0
|
||||||
|
|
||||||
|
def reset_data_buffer(self):
|
||||||
|
self.ResetAll()
|
||||||
|
|
||||||
|
def reset_state(self):
|
||||||
|
self.ResetAll()
|
||||||
|
|
||||||
|
def interval_init(self, decoder_class):
|
||||||
|
"""初始化事件检测参数"""
|
||||||
|
import ast
|
||||||
|
from PubLibrary.InifileHelper import IniRead
|
||||||
|
|
||||||
|
if decoder_class == 'ssmvep':
|
||||||
|
interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
|
||||||
|
self.interval_epoch = [int(i * 250) for i in interval_epoch]
|
||||||
|
self.train_epoch = [int(self.interval_epoch[0]),
|
||||||
|
int(self.interval_epoch[1] + 0.1 * 250)]
|
||||||
|
self.latency = (self.interval_epoch[1] + 0.1 * 250) // 5
|
||||||
|
self.train_latency = (self.train_epoch[1] + 0.1 * 250) // 5
|
||||||
|
|
||||||
|
elif decoder_class == 'mi':
|
||||||
|
interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
|
||||||
|
self.interval_epoch = [int(i * 250) for i in interval_epoch]
|
||||||
|
self.train_epoch = self.interval_epoch.copy()
|
||||||
|
self.latency = self.interval_epoch[1] // 5
|
||||||
|
self.train_latency = self.latency
|
||||||
|
|
||||||
|
self.count_events = {}
|
||||||
|
self._event_inner_idx = -1
|
||||||
|
self._epoch_finished = False
|
||||||
|
self.pack_contain_event = False
|
||||||
|
self.predict_event = 99
|
||||||
|
self.events = [1, 2, self.predict_event]
|
||||||
|
self._interval_inited = True
|
||||||
|
|
||||||
|
# ========== 事件检测 ==========
|
||||||
|
def detect_event(self, data_matrix):
|
||||||
|
"""
|
||||||
|
检测事件通道中的触发信号
|
||||||
|
|
||||||
|
@param data_matrix: shape (66, N) - N个采样点的数据
|
||||||
|
第65行(索引64) = 事件通道
|
||||||
|
第66行(索引65) = 标签通道
|
||||||
|
@return: 是否检测到事件
|
||||||
|
"""
|
||||||
|
if data_matrix.shape[1] == 0:
|
||||||
|
return False
|
||||||
|
|
||||||
|
self.pack_contain_event = False
|
||||||
|
event_channel = data_matrix[64, :] # 第65通道 = 标签值(event值)
|
||||||
|
label_channel = data_matrix[65, :] # 第66通道 = 标签序号(label index)
|
||||||
|
|
||||||
|
events = event_channel.tolist()
|
||||||
|
|
||||||
|
with self._event_lock:
|
||||||
|
self._event_inner_idx = -1
|
||||||
|
self.current_event_label = 0
|
||||||
|
|
||||||
|
for idx, event in enumerate(events):
|
||||||
|
if int(event) in self.events:
|
||||||
|
self._event_inner_idx = idx
|
||||||
|
self.current_label_index = int(label_channel[idx])
|
||||||
|
self.pack_contain_event = True
|
||||||
|
|
||||||
|
new_key = f"{event}_{time.time()}"
|
||||||
|
latency = self.latency if event == self.predict_event else self.train_latency
|
||||||
|
self.count_events[new_key] = latency + 1
|
||||||
|
|
||||||
|
# 延迟计数递减
|
||||||
|
drop_items = []
|
||||||
|
for key, value in self.count_events.items():
|
||||||
|
value = value - 1
|
||||||
|
if value == 0:
|
||||||
|
drop_items.append(key)
|
||||||
|
self.count_events[key] = value
|
||||||
|
for key in drop_items:
|
||||||
|
del self.count_events[key]
|
||||||
|
|
||||||
|
if drop_items:
|
||||||
|
self._epoch_finished = True
|
||||||
|
# 检测到事件时,清除RingBuffer中之前的数据,只保留当前包
|
||||||
|
if self.pack_contain_event:
|
||||||
|
self.__ringBuffer.resetAllPara()
|
||||||
|
return True
|
||||||
|
|
||||||
|
self._epoch_finished = False
|
||||||
|
return False
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
self.running = True
|
||||||
|
print(f"Server running - CMD: {self.cmd_port}, DATA: {self.data_port}")
|
||||||
|
|
||||||
|
cmd_poller = zmq.Poller()
|
||||||
|
cmd_poller.register(self.cmd_socket, zmq.POLLIN)
|
||||||
|
|
||||||
|
data_poller = zmq.Poller()
|
||||||
|
data_poller.register(self.data_socket, zmq.POLLIN)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while self.running:
|
||||||
|
# --- 处理发送队列 (指令通道) ---
|
||||||
|
while not self.send_queue.empty():
|
||||||
|
method, params = self.send_queue.get()
|
||||||
|
if self.cmd_clients:
|
||||||
|
try:
|
||||||
|
msg = {'method': method, 'params': params}
|
||||||
|
msg_bytes = json.dumps(msg).encode('utf-8')
|
||||||
|
for client_id in list(self.cmd_clients):
|
||||||
|
try:
|
||||||
|
self.cmd_socket.send_multipart([client_id, b'', msg_bytes])
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# --- 处理指令通道 ---
|
||||||
|
socks = dict(cmd_poller.poll(10))
|
||||||
|
if self.cmd_socket in socks:
|
||||||
|
self._handle_cmd_socket()
|
||||||
|
|
||||||
|
# --- 处理数据通道 ---
|
||||||
|
socks = dict(data_poller.poll(10))
|
||||||
|
if self.data_socket in socks:
|
||||||
|
self._handle_data_socket()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Server error: {e}")
|
||||||
|
finally:
|
||||||
|
self.running = False
|
||||||
|
self.cmd_socket.close()
|
||||||
|
self.data_socket.close()
|
||||||
|
self.context.term()
|
||||||
|
|
||||||
|
def _handle_cmd_socket(self):
|
||||||
|
"""处理指令通道消息"""
|
||||||
|
try:
|
||||||
|
frames = self.cmd_socket.recv_multipart()
|
||||||
|
if len(frames) < 3:
|
||||||
|
return
|
||||||
|
ident, _, message_bytes = frames[:3]
|
||||||
|
self.cmd_clients.add(ident)
|
||||||
|
self.clients.add(ident)
|
||||||
|
|
||||||
|
message = json.loads(message_bytes.decode('utf-8'))
|
||||||
|
method = message.get("method")
|
||||||
|
params = message.get("params")
|
||||||
|
|
||||||
|
print(f"[CMD] {method}: {params}")
|
||||||
|
|
||||||
|
if method == "sync":
|
||||||
|
self.state_mode = 'sync'
|
||||||
|
elif method == "targetFreqs":
|
||||||
|
if isinstance(params, list) and params != self.targetFreqs:
|
||||||
|
self.targetFreqs = params
|
||||||
|
self.changeTarget = True
|
||||||
|
elif method == "decoderClass":
|
||||||
|
if isinstance(params, str) and params != self.decoder_class:
|
||||||
|
self.decoder_class = params
|
||||||
|
self.decoder_switch = True
|
||||||
|
elif method == "getReport":
|
||||||
|
self.getReport = True
|
||||||
|
elif method == "train":
|
||||||
|
self.state_mode = 'train'
|
||||||
|
self.StartTrain = True
|
||||||
|
self.currentLabel = params
|
||||||
|
elif method == "predict":
|
||||||
|
self.state_mode = 'predict'
|
||||||
|
if params == 1:
|
||||||
|
self.StartDecode = True
|
||||||
|
elif params == 2:
|
||||||
|
self.IsExitApp = True
|
||||||
|
self.running = False
|
||||||
|
elif method == "rest":
|
||||||
|
self.state_mode = 'rest'
|
||||||
|
elif method == "impedance":
|
||||||
|
if params == 1:
|
||||||
|
self.open_Impedance = True
|
||||||
|
self.get_Impedance = True
|
||||||
|
elif params == 2:
|
||||||
|
self.open_Impedance = False
|
||||||
|
self.get_Impedance = False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"CMD socket error: {e}")
|
||||||
|
|
||||||
|
def _handle_data_socket(self):
|
||||||
|
"""处理数据通道消息 (EEG数据)
|
||||||
|
|
||||||
|
上位机数据格式:
|
||||||
|
- 数据帧: [identity, '', meta_json, data_buffer]
|
||||||
|
data_buffer = [N, 66] float32 -> 转置为 [66, N]
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
frames = self.data_socket.recv_multipart()
|
||||||
|
if len(frames) < 4:
|
||||||
|
return
|
||||||
|
ident, _, message_bytes = frames[:3]
|
||||||
|
self.data_clients.add(ident)
|
||||||
|
|
||||||
|
meta = json.loads(message_bytes.decode('utf-8'))
|
||||||
|
|
||||||
|
# data: [N, 66] -> 转置 -> [66, N]
|
||||||
|
raw_data = np.frombuffer(frames[3], dtype=np.float32)
|
||||||
|
n_samples, n_channels = meta.get('shape', [5, 66])
|
||||||
|
data_matrix = raw_data.reshape(n_samples, n_channels).T.astype(np.float32)
|
||||||
|
|
||||||
|
# 写入 RingBuffer
|
||||||
|
with self.__ringBuffer.RingBufferLock:
|
||||||
|
self.__ringBuffer.appendBuffer(data_matrix)
|
||||||
|
|
||||||
|
# 事件检测
|
||||||
|
self.detect_event(data_matrix)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"DATA socket error: {e}")
|
||||||
|
|
||||||
|
# ========== 各范式数据访问接口 ==========
|
||||||
|
def get_MIData(self):
|
||||||
|
"""获取MI导联数据 (21通道 + 事件)"""
|
||||||
|
data = self.getData(self.GetDataLenCount())
|
||||||
|
rows_to_extract = [8, 15, 12, 14, 18, 23, 16, 59, 50, 58, 17, 45, 29, 11, 10, 19, 20, 61, 51, 60, 21, 64, 65]
|
||||||
|
row_to_select = np.array(rows_to_extract)
|
||||||
|
if data.shape[1] > 0:
|
||||||
|
return data[row_to_select, :]
|
||||||
|
return np.zeros((len(rows_to_extract), 0))
|
||||||
|
|
||||||
|
def get_SSMVEPData(self):
|
||||||
|
"""获取SSMVEP导联数据 (8通道 + 事件)"""
|
||||||
|
data = self.getData(self.GetDataLenCount())
|
||||||
|
rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64, 65]
|
||||||
|
row_to_select = np.array(rows_to_extract)
|
||||||
|
if data.shape[1] > 0:
|
||||||
|
return data[row_to_select, :]
|
||||||
|
return np.zeros((len(rows_to_extract), 0))
|
||||||
|
|
||||||
|
def getDataViaSSVEP(self, count):
|
||||||
|
"""获取SSVEP数据 (8通道 + 事件)"""
|
||||||
|
data = self.getData(count)
|
||||||
|
rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64]
|
||||||
|
row_to_select = np.array(rows_to_extract)
|
||||||
|
if data.shape[1] > 0:
|
||||||
|
return data[row_to_select, :]
|
||||||
|
return np.zeros((len(rows_to_extract), 0))
|
||||||
|
|
||||||
|
def get_concentrateData(self, count):
|
||||||
|
"""获取专注力数据 (2通道)"""
|
||||||
|
data = self.getData(count)
|
||||||
|
rows_to_extract = [0, 1]
|
||||||
|
row_to_select = np.array(rows_to_extract)
|
||||||
|
if data.shape[1] > 0:
|
||||||
|
return data[row_to_select, :]
|
||||||
|
return np.zeros((len(rows_to_extract), 0))
|
||||||
|
|
||||||
|
def get_blinkData(self, count):
|
||||||
|
"""获取眨眼数据 (2通道)"""
|
||||||
|
data = self.getData(count)
|
||||||
|
rows_to_extract = [0, 1]
|
||||||
|
row_to_select = np.array(rows_to_extract)
|
||||||
|
if data.shape[1] > 0:
|
||||||
|
return data[row_to_select, :]
|
||||||
|
return np.zeros((len(rows_to_extract), 0))
|
||||||
|
|
||||||
|
def getImpedance(self, data, decoder_class):
|
||||||
|
"""计算阻抗(ZMQ模式下不可用)"""
|
||||||
|
return np.zeros(8)
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
self.running = False
|
||||||
|
self.cmd_socket.close()
|
||||||
|
self.data_socket.close()
|
||||||
|
self.context.term()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
server = zmqServer()
|
||||||
|
server.start()
|
||||||
@@ -8,7 +8,6 @@ import os
|
|||||||
# import logging
|
# import logging
|
||||||
import base64
|
import base64
|
||||||
import io
|
import io
|
||||||
import math
|
|
||||||
|
|
||||||
# logger = logging.getLogger(__name__)
|
# logger = logging.getLogger(__name__)
|
||||||
#
|
#
|
||||||
@@ -23,7 +22,7 @@ import math
|
|||||||
|
|
||||||
|
|
||||||
class Calculate():
|
class Calculate():
|
||||||
def __init__(self, Threshold_value_low, Threshold_value_high, fs=250, win_len=10, config=None):
|
def __init__(self, Threshold_value_low, Threshold_value_high, fs=250, win_len=10):
|
||||||
self.Threshold_value_low = Threshold_value_low
|
self.Threshold_value_low = Threshold_value_low
|
||||||
self.Threshold_value_high = Threshold_value_high
|
self.Threshold_value_high = Threshold_value_high
|
||||||
self.fs = fs
|
self.fs = fs
|
||||||
@@ -32,73 +31,47 @@ class Calculate():
|
|||||||
self.EVI_result = []
|
self.EVI_result = []
|
||||||
self.eegQueue = deque(maxlen=win_len)
|
self.eegQueue = deque(maxlen=win_len)
|
||||||
|
|
||||||
|
# # 存储历史数据用于绘图
|
||||||
|
# self.beta_history = []
|
||||||
|
# self.alpha_history = []
|
||||||
|
# self.theta_history = []
|
||||||
|
# self.focus_history = []
|
||||||
|
# self.timestamp_history = []
|
||||||
|
#
|
||||||
|
# # 记录开始时间
|
||||||
|
# self.start_time = None
|
||||||
|
# self.recording = False
|
||||||
|
#
|
||||||
|
# # 图表保存路径
|
||||||
|
# self.chart_dir = "reports"
|
||||||
|
# if not os.path.exists(self.chart_dir):
|
||||||
|
# os.makedirs(self.chart_dir)
|
||||||
|
# print(f"[调试] 创建目录: {self.chart_dir}")
|
||||||
|
|
||||||
# 初始化滤波器
|
# 初始化滤波器
|
||||||
self.b_notch, self.a_notch = signal.iirnotch(50 / (self.fs/2), 30)
|
self.b_notch, self.a_notch = signal.iirnotch(50 / (self.fs/2), 30)
|
||||||
self.b_design = signal.firwin(65, [2 / (self.fs/2), 40 / (self.fs/2)], pass_zero=False)
|
self.b_design = signal.firwin(65, [2 / (self.fs/2), 40 / (self.fs/2)], pass_zero=False)
|
||||||
|
|
||||||
self.last_focus = None
|
|
||||||
# 异步滤波系数配置(核心手感控制纽)
|
|
||||||
self.alpha_up = 1 # 上升系数:较小,保证分数平滑爬升,过滤偶发的瞬时高能量
|
|
||||||
# alpha_down / shrink_factor 从 config.ini 读取,方便上位机调参
|
|
||||||
if config:
|
|
||||||
self.alpha_down = float(config.get('alpha_down', 0.8))
|
|
||||||
self.shrink_factor = float(config.get('shrink_factor', 0.5))
|
|
||||||
else:
|
|
||||||
self.alpha_down = 0.8
|
|
||||||
self.shrink_factor = 0.5
|
|
||||||
print("[调试] Calculate 类初始化完成")
|
print("[调试] Calculate 类初始化完成")
|
||||||
|
|
||||||
def calculate_focus(self, beta, alpha, theta):
|
def calculate_focus(self, beta, alpha, theta):
|
||||||
"""
|
"""
|
||||||
专注度计算 - 三区间门限异步滤波版本
|
专注度计算 - 固定映射版本
|
||||||
"""
|
"""
|
||||||
# 0. 频带特征预处理
|
|
||||||
theta_mod = theta ** 0.7
|
|
||||||
|
|
||||||
# 原始比值
|
# 原始比值
|
||||||
raw = beta / (alpha + theta_mod + 1e-10)
|
raw = beta / (alpha + theta + 1e-10)
|
||||||
|
|
||||||
exponent = 2.0
|
# Sigmoid 映射:让 raw 在 0.3-1.5 区间敏感
|
||||||
|
# 参数可调:
|
||||||
|
# k = 12 (斜率,越大越陡)
|
||||||
|
# x0 = 0.6 (中心点,raw=0.6时focus≈50)
|
||||||
|
k = 12.0
|
||||||
|
x0 = 0.6
|
||||||
|
focus = 100.0 / (1.0 + np.exp(-k * (raw - x0)))
|
||||||
|
|
||||||
# 1. 防止脑电比值出现负数异常值
|
# 可选:添加滑动平均平滑
|
||||||
raw_input = max(raw, 0.0)
|
|
||||||
|
|
||||||
# 2. 2次幂纵轴压缩映射 (shrink_factor 从 config.ini 读取)
|
|
||||||
focus_raw = 100 * self.shrink_factor * (raw_input ** exponent)
|
|
||||||
|
|
||||||
# 3. 计算当前帧的瞬时分数 (基准量级 0-120)
|
|
||||||
instant_focus = 120 * (1.0 - np.exp(-focus_raw / 100.0))
|
|
||||||
|
|
||||||
# 4. 核心修改:三区间门限时域滤波
|
|
||||||
if self.last_focus is None:
|
|
||||||
# 冷启动:首帧直接赋值
|
|
||||||
focus = instant_focus
|
|
||||||
else:
|
|
||||||
# 判断当前瞬时分数是否处于【极端区】(80以上 或 60以下)
|
|
||||||
if instant_focus > 85.0 or instant_focus < 60.0:
|
|
||||||
# 执行异步低通时域滤波
|
|
||||||
if instant_focus >= self.last_focus:
|
|
||||||
# 趋势上升:慢爬升
|
|
||||||
focus = self.alpha_up * instant_focus + (1 - self.alpha_up) * self.last_focus
|
|
||||||
else:
|
|
||||||
# 趋势下降:快跌落
|
|
||||||
focus = self.alpha_down * instant_focus + (1 - self.alpha_down) * self.last_focus
|
|
||||||
else:
|
|
||||||
# 【高灵敏自由区】(60 <= instant_focus <= 80)
|
|
||||||
# 不执行异步滤波,分数直接跟随瞬时值,保证中间状态绝对跟手
|
|
||||||
focus = instant_focus
|
|
||||||
|
|
||||||
# 5. 更新历史状态缓存
|
|
||||||
self.last_focus = focus
|
|
||||||
|
|
||||||
# 打印在线调试日志,方便观察区间切换
|
|
||||||
zone_tag = "极端区(滤波)" if (instant_focus > 80 or instant_focus < 60) else "自由区(直通)"
|
|
||||||
print(f"原始特征比值 raw: {raw:.4f} | 瞬时分数: {instant_focus:.1f} | 滤波后分数: {focus:.1f}")
|
|
||||||
|
|
||||||
# 最终返回整型
|
|
||||||
return int(focus)
|
return int(focus)
|
||||||
|
|
||||||
|
|
||||||
def calculate_all(self, data, fs, nperseg=1000):
|
def calculate_all(self, data, fs, nperseg=1000):
|
||||||
mean_x = np.mean(data, axis=-1, keepdims=True)
|
mean_x = np.mean(data, axis=-1, keepdims=True)
|
||||||
data = data - mean_x
|
data = data - mean_x
|
||||||
@@ -346,16 +319,14 @@ class Calculate():
|
|||||||
if eegData.size == 0:
|
if eegData.size == 0:
|
||||||
return None
|
return None
|
||||||
eegData -= np.mean(eegData, axis=-1, keepdims=True)
|
eegData -= np.mean(eegData, axis=-1, keepdims=True)
|
||||||
# eegData = signal.lfilter(self.b_notch, self.a_notch, eegData) # 陷波
|
eegData = signal.lfilter(self.b_notch, self.a_notch, eegData)
|
||||||
# eegData = signal.lfilter(self.b_design, 1, eegData) # 滤波
|
eegData = signal.lfilter(self.b_design, 1, eegData)
|
||||||
focus_score, CLI_score, beta_psd, alpha_psd, theta_psd = self.calculate_all(eegData, fs=self.fs, nperseg=1000)
|
focus_score, CLI_score, beta, alpha, theta = self.calculate_all(eegData, fs=self.fs, nperseg=1000)
|
||||||
|
|
||||||
# self.add_data_point(focus_score, beta_psd, alpha_psd, theta_psd) # 已注释(方法已移除)
|
# self.add_data_point(focus_score, beta, alpha, theta)
|
||||||
|
|
||||||
# return (focus_score)
|
|
||||||
return (focus_score, beta_psd)
|
|
||||||
# return None
|
|
||||||
|
|
||||||
|
return focus_score
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class Calculate2():
|
class Calculate2():
|
||||||
|
|||||||
14
config.ini
14
config.ini
@@ -19,10 +19,10 @@ Serial_port = COM44
|
|||||||
algo_log_level = DEBUG
|
algo_log_level = DEBUG
|
||||||
console_output = 1
|
console_output = 1
|
||||||
|
|
||||||
; 64 导设备配置
|
|
||||||
[device_type_1]
|
; 64 导设备配置 1; 32 2; 24 3; 16 4; 8 5; 4 6;
|
||||||
sample_rate = 250
|
[device_type] = 1
|
||||||
frame_points = 5
|
device_sample_rate = 250
|
||||||
channel_nums = 66
|
device_channel_nums = 66
|
||||||
channel_names = ['FP1', 'FP2', 'PO6', 'POZ', 'F3', 'F4', 'FPZ', 'AF4', 'FC3', 'PO8', 'CP2', 'CP1', 'FCZ', 'PO5', 'FC2', 'FC1', 'C3', 'C4', 'FC4', 'CP4', 'P3', 'P4', 'F5', 'C5', 'F6', 'PO4', 'CP6', 'CP5', 'PO3', 'CP3', 'FC6', 'FC5', 'CB1', 'CB2', 'P5', 'AF7', 'A1', 'T7', 'FT7', 'TP7', 'FT8', 'AF8', 'F8', 'F7', 'P6', 'C6', 'O2', 'O1', 'T8', 'P7', 'CZ', 'PZ', 'P8', 'FZ', 'OZ', 'PO7', 'TP8', 'AF3', 'C2', 'C1', 'P2', 'P1', 'F2', 'F1', 'label', 'label_tag']
|
device_channel_names = ['FP1', 'FP2', 'PO6', 'POZ', 'F3', 'F4', 'FPZ', 'AF4', 'FC3', 'PO8', 'CP2', 'CP1', 'FCZ', 'PO5', 'FC2', 'FC1', 'C3', 'C4', 'FC4', 'CP4', 'P3', 'P4', 'F5', 'C5', 'F6', 'PO4', 'CP6', 'CP5', 'PO3', 'CP3', 'FC6', 'FC5', 'CB1', 'CB2', 'P5', 'AF7', 'A1', 'T7', 'FT7', 'TP7', 'FT8', 'AF8', 'F8', 'F7', 'P6', 'C6', 'O2', 'O1', 'T8', 'P7', 'CZ', 'PZ', 'P8', 'FZ', 'OZ', 'PO7', 'TP8', 'AF3', 'C2', 'C1', 'P2', 'P1', 'F2', 'F1', 'label', 'label_tag']
|
||||||
channel_index = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65]
|
device_channel_index = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65]
|
||||||
|
|||||||
163
datamock.py
163
datamock.py
@@ -1,163 +0,0 @@
|
|||||||
import zmq
|
|
||||||
import numpy as np
|
|
||||||
import time
|
|
||||||
import threading
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
# ========== 参数配置 ==========
|
|
||||||
FS = 250 # 采样率 Hz
|
|
||||||
N_SAMPLES_PER_PKT = 5 # 每包采样点数
|
|
||||||
N_CHAN = 66 # 通道数: 64 EEG + 1 标签值 + 1 标签序号
|
|
||||||
EEG_FREQ = 10 # EEG 正弦波频率 Hz
|
|
||||||
EEG_AMP = 100.0 # EEG 幅值 100μV
|
|
||||||
LABEL_INTERVAL = 5 # 标签间隔秒数
|
|
||||||
SERVER_ADDR = 'tcp://127.0.0.1:8100'
|
|
||||||
|
|
||||||
# 发送间隔: 每包 5 采样点 / 250Hz = 20ms
|
|
||||||
PKT_INTERVAL = N_SAMPLES_PER_PKT / FS
|
|
||||||
|
|
||||||
|
|
||||||
def build_packet(global_sample_idx):
|
|
||||||
"""
|
|
||||||
生成一包 [5, 66] 的 float64 数据
|
|
||||||
:param global_sample_idx: 当前包第一个采样点在全局序列中的索引 (从 0 开始)
|
|
||||||
:return: np.ndarray shape [5, 66]
|
|
||||||
"""
|
|
||||||
# 当前包内 5 个采样点对应的时间(秒)
|
|
||||||
t = (global_sample_idx + np.arange(N_SAMPLES_PER_PKT)) / FS
|
|
||||||
|
|
||||||
# Ch0-63: EEG 10Hz 正弦波,幅值 100μV
|
|
||||||
# t shape [5,],sin 乘以标量后仍是 [5,],需要 reshape 为 [5,1] 再广播到 64 通道
|
|
||||||
eeg = (EEG_AMP * np.sin(2 * np.pi * EEG_FREQ * t)).reshape(N_SAMPLES_PER_PKT, 1) # [5, 1]
|
|
||||||
eeg = np.tile(eeg, (1, 64)) # [5, 64]
|
|
||||||
|
|
||||||
# Ch64: 标签值通道,初始化为 0
|
|
||||||
event = np.zeros((N_SAMPLES_PER_PKT, 1), dtype=np.float64)
|
|
||||||
|
|
||||||
# Ch65: 标签序号通道,初始化为 0
|
|
||||||
label_idx = np.zeros((N_SAMPLES_PER_PKT, 1), dtype=np.float64)
|
|
||||||
|
|
||||||
# 拼成 [5, 66]
|
|
||||||
packet = np.concatenate([eeg, event, label_idx], axis=1).astype(np.float64)
|
|
||||||
return packet
|
|
||||||
|
|
||||||
|
|
||||||
def should_send_label(global_sample_idx):
|
|
||||||
"""
|
|
||||||
判断当前包是否包含标签触发点(每 5s 的最后一个采样点)
|
|
||||||
采样点索引从 0 开始,每 5s = 1250 个采样点
|
|
||||||
最后一个采样点索引: 1249, 2499, 3749, ...
|
|
||||||
由于每包 5 个采样点,标签点落在包内的最后一个采样点位置
|
|
||||||
即当前包起始索引 global_sample_idx 必须使得:
|
|
||||||
global_sample_idx <= 标签点索引 < global_sample_idx + N_SAMPLES_PER_PKT
|
|
||||||
也就是 global_sample_idx <= 1249 < global_sample_idx + 5
|
|
||||||
即 global_sample_idx = 1245, 2495, 3745, ...
|
|
||||||
即 global_sample_idx = n * LABEL_INTERVAL * FS - N_SAMPLES_PER_PKT
|
|
||||||
"""
|
|
||||||
samples_per_interval = LABEL_INTERVAL * FS
|
|
||||||
# 检查当前包是否包含 interval 的最后一个采样点
|
|
||||||
# 标签点索引 = n * 1250 - 1,当 global_sample_idx = n*1250-5 时,标签在包内索引 4
|
|
||||||
return (global_sample_idx + N_SAMPLES_PER_PKT - 1) % samples_per_interval == samples_per_interval - 1
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
ctx = zmq.Context()
|
|
||||||
sock = ctx.socket(zmq.DEALER)
|
|
||||||
sock.connect(SERVER_ADDR)
|
|
||||||
print(f"[{datetime.now().strftime('%H:%M:%S')}] ZMQ Dealer 连接到 {SERVER_ADDR}")
|
|
||||||
|
|
||||||
# 后台消费线程:持续 recv 从 ROUTER 返回的数据,避免 server 发送队列积压
|
|
||||||
recv_count = [0]
|
|
||||||
stop_recv = threading.Event()
|
|
||||||
|
|
||||||
def consumer_thread():
|
|
||||||
"""消费线程:阻塞 recv,丢弃收到的数据,仅用于清空 ROUTER 发送队列"""
|
|
||||||
while not stop_recv.is_set():
|
|
||||||
try:
|
|
||||||
frames = sock.recv_multipart(zmq.NOBLOCK)
|
|
||||||
recv_count[0] += 1
|
|
||||||
# 收到的格式: [identity, '', filtered_data_bytes]
|
|
||||||
if recv_count[0] % 500 == 0:
|
|
||||||
print(f"[{datetime.now().strftime('%H:%M:%S')}] 消费线程已丢弃 {recv_count[0]} 帧滤波数据")
|
|
||||||
except zmq.Again:
|
|
||||||
time.sleep(0.01)
|
|
||||||
except zmq.error.Again: # 兼容旧版
|
|
||||||
time.sleep(0.01)
|
|
||||||
|
|
||||||
consumer = threading.Thread(target=consumer_thread, daemon=True)
|
|
||||||
consumer.start()
|
|
||||||
print(f"[{datetime.now().strftime('%H:%M:%S')}] 消费线程已启动(daemon)")
|
|
||||||
|
|
||||||
global_sample_idx = 0 # 全局采样点计数器
|
|
||||||
label_type = 1 # 当前标签类型: 1 或 2
|
|
||||||
label1_count = 0 # label=1 的序号计数器
|
|
||||||
label2_count = 0 # label=2 的序号计数器
|
|
||||||
packet_count = 0 # 已发送包数
|
|
||||||
|
|
||||||
print(f"[{datetime.now().strftime('%H:%M:%S')}] 开始发送模拟数据 ...")
|
|
||||||
print(f" 采样率: {FS}Hz | 每包 {N_SAMPLES_PER_PKT} 采样点 | 发送间隔 {PKT_INTERVAL*1000:.0f}ms")
|
|
||||||
print(f" EEG: {EEG_FREQ}Hz 正弦波 | 幅值 {EEG_AMP}μV")
|
|
||||||
print(f" 标签: 每 {LABEL_INTERVAL}s 末尾采样点触发 | label 1/2 交替")
|
|
||||||
print("-" * 50)
|
|
||||||
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
t_start = time.perf_counter()
|
|
||||||
|
|
||||||
# 构建当前包
|
|
||||||
packet = build_packet(global_sample_idx)
|
|
||||||
|
|
||||||
# 检查是否需要放置标签
|
|
||||||
if should_send_label(global_sample_idx):
|
|
||||||
if label_type == 1:
|
|
||||||
label1_count += 1
|
|
||||||
label_value = 1
|
|
||||||
label_number = label1_count
|
|
||||||
else:
|
|
||||||
label2_count += 1
|
|
||||||
label_value = 2
|
|
||||||
label_number = label2_count
|
|
||||||
|
|
||||||
# 标签放在当前包最后一个采样点(索引 4)
|
|
||||||
packet[4, 64] = label_value
|
|
||||||
packet[4, 65] = label_number
|
|
||||||
|
|
||||||
ts = datetime.now().strftime('%H:%M:%S')
|
|
||||||
print(f"[{ts}] 标签触发: label={label_value}, 序号={label_number} "
|
|
||||||
f"(global_sample_idx={global_sample_idx})")
|
|
||||||
|
|
||||||
# 交替标签类型
|
|
||||||
label_type = 2 if label_type == 1 else 1
|
|
||||||
|
|
||||||
# 发送: multipart 3帧 [identity, '', data]
|
|
||||||
# 使用标准格式(3帧),ROUTER 会自动附加 ZMQ 分配的客户端身份
|
|
||||||
sock.send_multipart([
|
|
||||||
b'',
|
|
||||||
packet.tobytes()
|
|
||||||
])
|
|
||||||
|
|
||||||
# 每 50 包打印一次进度
|
|
||||||
if packet_count % 50 == 0:
|
|
||||||
ts = datetime.now().strftime('%H:%M:%S')
|
|
||||||
print(f"[{ts}] 已发送 {packet_count} 包 (global_sample_idx={global_sample_idx})")
|
|
||||||
|
|
||||||
global_sample_idx += N_SAMPLES_PER_PKT
|
|
||||||
packet_count += 1
|
|
||||||
|
|
||||||
# 精确控制发送节奏: 等待到 PKT_INTERVAL 秒
|
|
||||||
elapsed = time.perf_counter() - t_start
|
|
||||||
sleep_time = PKT_INTERVAL - elapsed
|
|
||||||
if sleep_time > 0:
|
|
||||||
time.sleep(sleep_time)
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print(f"\n[{datetime.now().strftime('%H:%M:%S')}] 停止发送,共发送 {packet_count} 包")
|
|
||||||
finally:
|
|
||||||
stop_recv.set()
|
|
||||||
consumer.join(timeout=2)
|
|
||||||
sock.close()
|
|
||||||
ctx.term()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
||||||
56
logs/log.py
56
logs/log.py
@@ -1,44 +1,51 @@
|
|||||||
|
# log.py
|
||||||
import os
|
import os
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import logging
|
import logging
|
||||||
from logging.handlers import RotatingFileHandler
|
from logging.handlers import RotatingFileHandler
|
||||||
import inspect # 新增导入
|
|
||||||
from PubLibrary.InifileHelper import IniRead
|
from PubLibrary.InifileHelper import IniRead
|
||||||
|
|
||||||
|
|
||||||
console_output = IniRead('system', 'console_output', '1')
|
console_output = IniRead('system', 'console_output', '1')
|
||||||
log_level = IniRead('system', 'algo_log_level', 'INFO')
|
log_level = IniRead('system', 'algo_log_level', 'INFO')
|
||||||
|
|
||||||
|
# 新增:日志去重缓存,key为日志内容,value为是否已打印
|
||||||
log_once_cache = set()
|
log_once_cache = set()
|
||||||
|
|
||||||
# 缓存已经创建过的logger,避免重复创建handler
|
|
||||||
logger_cache = {}
|
|
||||||
|
|
||||||
def init_module_logger(logger_name):
|
def init_module_logger():
|
||||||
log_dir = './logs/'
|
"""
|
||||||
|
初始化指定模块的日志器
|
||||||
|
:return: 对应模块的logger实例
|
||||||
|
"""
|
||||||
|
# 缓存命中则直接返回
|
||||||
|
log_dir = './logs/' # 确保日志目录存在
|
||||||
os.makedirs(log_dir, exist_ok=True)
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
|
|
||||||
log_file = os.path.join(log_dir, f'algo_log_{datetime.now().strftime("%Y-%m-%d")}.log')
|
log_file = os.path.join(log_dir, f'algo_log_{datetime.now().strftime("%Y-%m-%d")}.log')
|
||||||
|
|
||||||
# 已创建直接返回
|
# 初始化logger
|
||||||
if logger_name in logger_cache:
|
logger = logging.getLogger('decoderLogger')
|
||||||
return logger_cache[logger_name]
|
|
||||||
|
|
||||||
logger = logging.getLogger(logger_name)
|
|
||||||
logger.setLevel(log_level)
|
logger.setLevel(log_level)
|
||||||
|
|
||||||
if logger.handlers:
|
if logger.handlers:
|
||||||
logger_cache[logger_name] = logger
|
|
||||||
return logger
|
return logger
|
||||||
|
|
||||||
|
# 设置日志轮转,最大10个文件,每个10MB
|
||||||
file_handler = RotatingFileHandler(
|
file_handler = RotatingFileHandler(
|
||||||
log_file,
|
log_file,
|
||||||
maxBytes=10*1024*1024,
|
maxBytes=10*1024*1024,
|
||||||
backupCount=10,
|
backupCount=10,
|
||||||
encoding='utf-8'
|
encoding='utf-8'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 日志格式
|
||||||
formatter = logging.Formatter(
|
formatter = logging.Formatter(
|
||||||
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||||
datefmt='%Y-%m-%d %H:%M:%S'
|
datefmt='%Y-%m-%d %H:%M:%S'
|
||||||
)
|
)
|
||||||
file_handler.setFormatter(formatter)
|
file_handler.setFormatter(formatter)
|
||||||
|
logger.setLevel(log_level)
|
||||||
logger.addHandler(file_handler)
|
logger.addHandler(file_handler)
|
||||||
|
|
||||||
if console_output:
|
if console_output:
|
||||||
@@ -46,26 +53,27 @@ def init_module_logger(logger_name):
|
|||||||
console_handler.setFormatter(formatter)
|
console_handler.setFormatter(formatter)
|
||||||
logger.addHandler(console_handler)
|
logger.addHandler(console_handler)
|
||||||
|
|
||||||
logger_cache[logger_name] = logger
|
|
||||||
return logger
|
return logger
|
||||||
|
|
||||||
|
|
||||||
def algo_log(content, level="INFO", record_once=False):
|
def algo_log(content, level="INFO", record_once=False):
|
||||||
# 向上回溯1层栈,拿到调用algo_log的代码文件信息
|
"""
|
||||||
frame = inspect.currentframe().f_back
|
通用日志函数,支持按模块输出到不同日志文件
|
||||||
file_path = frame.f_code.co_filename
|
:param content: 日志内容
|
||||||
# 提取py文件名(不带后缀/带后缀自选)
|
:param level: 日志级别(DEBUG/INFO/WARNING/ERROR/FATAL)
|
||||||
file_name = os.path.basename(file_path) # 例:zmqServer.py
|
:param record_once: 是否只打印一次该日志内容,默认False
|
||||||
# file_name = os.path.splitext(os.path.basename(file_path))[0] # 例:zmqServer
|
"""
|
||||||
|
# 初始化模块日志器
|
||||||
logger = init_module_logger(file_name)
|
logger = init_module_logger()
|
||||||
|
|
||||||
|
# 新增:处理只打印一次的逻辑
|
||||||
if record_once:
|
if record_once:
|
||||||
|
# 生成唯一标识(可根据需要调整,比如拼接level增强唯一性)
|
||||||
log_key = f"{level.upper()}_{content}"
|
log_key = f"{level.upper()}_{content}"
|
||||||
if log_key in log_once_cache:
|
if log_key in log_once_cache:
|
||||||
return
|
return # 已打印过,直接返回
|
||||||
log_once_cache.add(log_key)
|
log_once_cache.add(log_key) # 未打印过,加入缓存
|
||||||
|
|
||||||
|
# 根据级别输出日志
|
||||||
level_upper = level.upper()
|
level_upper = level.upper()
|
||||||
if level_upper == "DEBUG":
|
if level_upper == "DEBUG":
|
||||||
logger.debug(content)
|
logger.debug(content)
|
||||||
@@ -75,5 +83,5 @@ def algo_log(content, level="INFO", record_once=False):
|
|||||||
logger.error(content)
|
logger.error(content)
|
||||||
elif level_upper == "FATAL":
|
elif level_upper == "FATAL":
|
||||||
logger.fatal(content)
|
logger.fatal(content)
|
||||||
else:
|
else: # 默认INFO级别
|
||||||
logger.info(content)
|
logger.info(content)
|
||||||
BIN
online_Models/Model_2025-11-15-11-11-50.pth
Normal file
BIN
online_Models/Model_2025-11-15-11-11-50.pth
Normal file
Binary file not shown.
BIN
online_Models/Model_2025-11-17-16-55-25.pth
Normal file
BIN
online_Models/Model_2025-11-17-16-55-25.pth
Normal file
Binary file not shown.
BIN
online_Models/Model_2025-11-18-10-15-35.pth
Normal file
BIN
online_Models/Model_2025-11-18-10-15-35.pth
Normal file
Binary file not shown.
BIN
online_Models/Model_2026-01-08-14-55-10.pth
Normal file
BIN
online_Models/Model_2026-01-08-14-55-10.pth
Normal file
Binary file not shown.
BIN
online_Models/Model_2026-05-26-14-21-38.pth
Normal file
BIN
online_Models/Model_2026-05-26-14-21-38.pth
Normal file
Binary file not shown.
BIN
online_Models/Model_2026-05-26-15-26-09.pth
Normal file
BIN
online_Models/Model_2026-05-26-15-26-09.pth
Normal file
Binary file not shown.
BIN
online_Models/Model_2026-05-26-15-44-21.pth
Normal file
BIN
online_Models/Model_2026-05-26-15-44-21.pth
Normal file
Binary file not shown.
BIN
online_Models/Model_2026-05-26-16-06-24.pth
Normal file
BIN
online_Models/Model_2026-05-26-16-06-24.pth
Normal file
Binary file not shown.
BIN
online_Models/Model_2026-05-26-16-30-12.pth
Normal file
BIN
online_Models/Model_2026-05-26-16-30-12.pth
Normal file
Binary file not shown.
BIN
online_Models/Model_2026-05-26-16-44-52.pth
Normal file
BIN
online_Models/Model_2026-05-26-16-44-52.pth
Normal file
Binary file not shown.
BIN
online_Models/Model_2026-05-30-13-08-50.pth
Normal file
BIN
online_Models/Model_2026-05-30-13-08-50.pth
Normal file
Binary file not shown.
@@ -6,33 +6,32 @@ import time
|
|||||||
from Decoder import Decoder_main
|
from Decoder import Decoder_main
|
||||||
from PubLibrary.RunOnce import is_program_running
|
from PubLibrary.RunOnce import is_program_running
|
||||||
from PubLibrary.InifileHelper import IniRead
|
from PubLibrary.InifileHelper import IniRead
|
||||||
from logs.log import algo_log
|
|
||||||
|
|
||||||
def get_device_info(device_type):
|
def get_device_info(device_type):
|
||||||
|
|
||||||
|
|
||||||
section = f'device_type_{device_type}'
|
section = f'device_type_{device_type}'
|
||||||
device_info = {
|
device_info = {
|
||||||
'sample_rate': int(IniRead(section, 'sample_rate')) if IniRead(section, 'sample_rate') is not None else 250,
|
'device_sample_rate': int(IniRead(section, 'sample_rate')) if IniRead(section, 'sample_rate') is not None else 250,
|
||||||
'frame_points': int(IniRead(section, 'frame_points')) if IniRead(section, 'frame_points') is not None else 5,
|
|
||||||
'channel_nums': int(IniRead(section, 'channel_nums')) if IniRead(section, 'channel_nums') is not None else 66,
|
''
|
||||||
'channel_names': IniRead(section, 'channel_names') if IniRead(section, 'channel_names') is not None else None,
|
|
||||||
'channel_index': IniRead(section, 'channel_index') if IniRead(section, 'channel_index') is not None else None,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return device_info
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
if not is_program_running():
|
if not is_program_running():
|
||||||
# 解析命令行参数
|
# 解析命令行参数
|
||||||
# parser = argparse.ArgumentParser(description="EEG Decoder Application")
|
parser = argparse.ArgumentParser(description="EEG Decoder Application")
|
||||||
# parser.add_argument('-dt', '-t','--device-type', type=int, default=None, help="Device Type")
|
parser.add_argument('-dt', '-t','--device-type', type=int, default=None, help="Device Type")
|
||||||
# parser.add_argument('-dh', '--device-host', type=str, default=None, help="Device Host IP")
|
# parser.add_argument('-dh', '--device-host', type=str, default=None, help="Device Host IP")
|
||||||
# parser.add_argument('-dp', '--device-port', type=int, default=None, help="Device Port")
|
# parser.add_argument('-dp', '--device-port', type=int, default=None, help="Device Port")
|
||||||
# parser.add_argument('-uh', '--upper-host', type=str, default=None, help="Upper Computer Host IP")
|
# parser.add_argument('-uh', '--upper-host', type=str, default=None, help="Upper Computer Host IP")
|
||||||
# parser.add_argument('-up', '--upper-port', type=int, default=None, help="Upper Computer Port")
|
# parser.add_argument('-up', '--upper-port', type=int, default=None, help="Upper Computer Port")
|
||||||
# args = parser.parse_args()
|
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
device_info= get_device_info(args.device_type)
|
||||||
|
|
||||||
|
|
||||||
|
decoder = Decoder_main(device_info=device_info)
|
||||||
# decoder.connect(
|
# decoder.connect(
|
||||||
# device_type=args.device_type,
|
# device_type=args.device_type,
|
||||||
# device_host=args.device_host,
|
# device_host=args.device_host,
|
||||||
@@ -41,10 +40,6 @@ if __name__ == "__main__":
|
|||||||
# upper_port=args.upper_port
|
# upper_port=args.upper_port
|
||||||
# )
|
# )
|
||||||
|
|
||||||
device_info= get_device_info(1)
|
|
||||||
algo_log(f"device_info: {device_info}", level="DEBUG")
|
|
||||||
decoder = Decoder_main(device_info=device_info)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
decoder.start()
|
decoder.start()
|
||||||
while not decoder.zmqServer.IsExitApp:
|
while not decoder.zmqServer.IsExitApp:
|
||||||
|
|||||||
Reference in New Issue
Block a user