Compare commits
2 Commits
9690971f43
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
34c9115258 | ||
|
|
69b2802895 |
11
.gitignore
vendored
11
.gitignore
vendored
@@ -2,14 +2,10 @@
|
||||
__pycache__/
|
||||
|
||||
# Distribution / packaging
|
||||
release/
|
||||
build/
|
||||
dist/
|
||||
dist_nuitka/
|
||||
upperHost_stim/
|
||||
.vscode/
|
||||
#!upperHost_stim/MI_headless.py
|
||||
#!upperHost_stim/ssmvep_headless.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
@@ -28,8 +24,7 @@ venv.bak/
|
||||
*.xlsx
|
||||
*.mat
|
||||
*.json
|
||||
*.txt
|
||||
*.pth
|
||||
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate repository that is not distributed with PyCharm itself
|
||||
|
||||
498
Decoder.py
498
Decoder.py
@@ -1,7 +1,6 @@
|
||||
import ast
|
||||
import glob
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
from datetime import datetime
|
||||
import multiprocessing as mp
|
||||
@@ -11,59 +10,58 @@ import torch
|
||||
from queue import Empty
|
||||
from scipy import signal
|
||||
from torch.autograd import Variable
|
||||
# from Device.SunnyLinker import SunnyLinker64
|
||||
from Device.SunnyLinker import SunnyLinker64
|
||||
from SSMVEP.algorithm.tdca import TDCA
|
||||
from SSMVEP.algorithm.base import generate_cca_references
|
||||
# from concentration.algorithm.calculate_focus import Calculate
|
||||
# from blinkdetection.algorithm.eye_detection import blink_detection
|
||||
from concentration.algorithm.calculate_focus import Calculate
|
||||
from blinkdetection.algorithm.eye_detection import blink_detection
|
||||
from Zmq.zmqServer import zmqServer
|
||||
from Zmq.zmqClient import zmqClient
|
||||
from MI.Algorithm.conformer_2class import onlineTrain
|
||||
from PubLibrary.InifileHelper import IniRead
|
||||
from logs.log import algo_log
|
||||
from SSVEP.dwfbcca import FbccaDw
|
||||
# from Tools.plot_MI_EEG import plotMain
|
||||
from Tools.plot_MI_EEG import plotMain
|
||||
from collections import deque
|
||||
from Zmq.filterProcess import SlidingFilter
|
||||
|
||||
save_train_data = int(IniRead('system', 'save_train_data', 0))
|
||||
|
||||
def get_root_path():
|
||||
"""
|
||||
Nuitka 打包专用:获取程序根目录(.py 或 .exe 所在目录)
|
||||
"""
|
||||
if getattr(sys, 'frozen', False):
|
||||
# 打包后:返回 exe 所在目录
|
||||
return os.path.dirname(sys.executable)
|
||||
else:
|
||||
# 开发时:返回 py 文件所在目录
|
||||
return os.path.dirname(os.path.abspath(__file__))
|
||||
MODEL_FOLDER = "online_Models"
|
||||
|
||||
|
||||
class Decoder_main(threading.Thread):
|
||||
def __init__(self, device_info=None):
|
||||
class Decoder_main(threading.Thread, device_type):
|
||||
def __init__(self, device_type=None):
|
||||
threading.Thread.__init__(self)
|
||||
self.device_info = device_info
|
||||
self.Runing=True
|
||||
self.decoder = None
|
||||
|
||||
self.fs = 250 # 采样率
|
||||
self.energy = 0 # 电量
|
||||
self.status_code = 0 # 与采集设备通信的状态码,0为异常,1为正常
|
||||
self.decoder_class = None #解码器类别
|
||||
|
||||
self.decodingSteps = 0 # 0=停止解码 1=预热 2=解码中 3=解码完成,发送解码结果
|
||||
self.device_info = {
|
||||
'device_type': None,
|
||||
'sample_rate': None,
|
||||
'channel_num': None,
|
||||
}
|
||||
|
||||
def connect(self, device_type=None, device_host=None, device_port=None, upper_host=None, upper_port=None):
|
||||
self.DeviceType = device_type if device_type is not None else int(IniRead('system', 'Device_type'))
|
||||
_device_host = device_host if device_host is not None else str(IniRead('system', 'Device_Host'))
|
||||
_device_port = device_port if device_port is not None else int(IniRead('system', 'Device_Port'))
|
||||
_upper_host = upper_host if upper_host is not None else str(IniRead('system', 'Upper_Host'))
|
||||
_upper_port = upper_port if upper_port is not None else int(IniRead('system', 'Upper_Port'))
|
||||
|
||||
self.zmqServer = zmqServer(device_info=self.device_info)
|
||||
self.zmqServer.start() # 启动ZMQ接收线程
|
||||
if self.DeviceType == 1:
|
||||
self.thread_data_server = SunnyLinker64(_device_host, _device_port, self.fs, 64, method='tcp')
|
||||
self.thread_data_server.host = _device_host
|
||||
self.thread_data_server.port = _device_port
|
||||
|
||||
self.sliding_filter = SlidingFilter(
|
||||
ring_buffer=self.zmqServer.filterBuffer,
|
||||
n_chan=self.zmqServer.device_info['channel_nums'],
|
||||
srate=self.zmqServer.device_info['sample_rate']
|
||||
)
|
||||
self.thread_data_server.toUv = True
|
||||
self.thread_data_server.start()
|
||||
|
||||
# 注册滤波结果回调(示例:打印数据形状)
|
||||
self.sliding_filter.filter_result_callback = self.zmqServer.send_filtered_data
|
||||
# 注册 beta_psd 广播回调,每秒通过 8099 端口发送给上位机
|
||||
self.sliding_filter.set_beta_broadcast_callback(lambda v: self.zmqServer.broadcast_message('beta_psd', v))
|
||||
self.zmqServer = zmqServer()
|
||||
self.zmqServer.start()
|
||||
|
||||
self.zmqClient = zmqClient(_upper_host, _upper_port)
|
||||
self.zmqClient.set_zmq_server(self.zmqServer)
|
||||
self.zmqClient.connect()
|
||||
|
||||
def is_valid_signal(self, data, threshold=1e5): # 判断当前信号是否为有效信号
|
||||
# data: (chans, samples)
|
||||
@@ -78,44 +76,45 @@ class Decoder_main(threading.Thread):
|
||||
:return:
|
||||
'''
|
||||
self.decoder_class = decoder_class
|
||||
if self.decoder_class == 'ssvep' or self.decoder_class == 'pvs':
|
||||
if decoder_class == 'ssvep' or decoder_class == 'pvs':
|
||||
self.n_chan = 8
|
||||
# self.thread_data_server.interval_inited = False
|
||||
self.thread_data_server.interval_inited = False
|
||||
DW_cost_method, self.DW_cost_tv = ast.literal_eval(IniRead('system', 'SSVEP_ThresholdValue'))
|
||||
self.ListFreq = self.zmqServer.targetFreqs
|
||||
self.num_target = len(self.ListFreq)
|
||||
if self.num_target == 0:
|
||||
return
|
||||
# 初始化对象 二代算法
|
||||
self.dw = FbccaDw(self.device_info['sample_rate'], self.num_target, self.n_chan, 5, 5,
|
||||
self.dw = FbccaDw(self.fs, self.num_target, self.n_chan, 5, 5,
|
||||
0.2, [2.0, 0.1], [8, 7], 50, DW_cost_method)
|
||||
# frequence band
|
||||
self.dw.filterFrequenceBank()
|
||||
self.dw.setNotchFilterPara()
|
||||
self.calculateCount = 0
|
||||
self.referenceData = self.dw.reference(self.ListFreq, int(50 * 0.2 * self.device_info['sample_rate']), 5)
|
||||
self.referenceData = self.dw.reference(self.ListFreq, int(50 * 0.2 * self.fs),
|
||||
5)
|
||||
self.dw.filterInit()
|
||||
self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时
|
||||
|
||||
elif decoder_class == 'ssmvep':
|
||||
self.zmqServer.interval_init(decoder_class)
|
||||
self.thread_data_server.interval_init(decoder_class)
|
||||
self.n_chan = 8
|
||||
self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch')) # [0.2, 2.2]
|
||||
self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
|
||||
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位
|
||||
self.single_train = 10 # 单类别数量
|
||||
self.num_target = 2 # 分类目标数目
|
||||
self.list_freqs = np.array([8, 9]) # 刺激频率
|
||||
self.list_phase = np.array([0, 0]) # 相位
|
||||
self.tdca = TDCA(padding_len=5, n_components=1)
|
||||
self.Yf = generate_cca_references(self.list_freqs, srate=self.device_info['sample_rate'], T=self.sample_length,
|
||||
self.Yf = generate_cca_references(self.list_freqs, srate=self.fs, T=self.sample_length,
|
||||
phases=self.list_phase, n_harmonics=5)
|
||||
self.parameter_init(5,45)
|
||||
|
||||
elif decoder_class == 'mi' or decoder_class == 'ma':
|
||||
self.zmqServer.interval_init(decoder_class)
|
||||
self.thread_data_server.interval_init(decoder_class)
|
||||
self.n_chan = 21
|
||||
self.interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch')) # [0.5, 4.5]
|
||||
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度4s,# 精确到小数点后6位
|
||||
self.interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
|
||||
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位
|
||||
self.single_train = 40 # 单类别数量
|
||||
self.num_target = 2 # 分类目标数目
|
||||
|
||||
@@ -127,7 +126,7 @@ class Decoder_main(threading.Thread):
|
||||
# self.win_len = 10
|
||||
# self.win_step = 1
|
||||
# self.low_threshold, self.high_threshold = ast.literal_eval(IniRead('system', 'concentration_ThresholdValue'))
|
||||
# self.calculate = Calculate(self.low_threshold, self.high_threshold, self.device_info['sample_rate'], self.win_len)
|
||||
# self.calculate = Calculate(self.low_threshold, self.high_threshold, self.fs, self.win_len)
|
||||
# self.interval_epoch = [0, 1]
|
||||
# self.parameter_init(2, 40)
|
||||
# # self.eegQueue moved to Calculate class
|
||||
@@ -139,8 +138,8 @@ class Decoder_main(threading.Thread):
|
||||
# self.total_samples = 0 # 总采样点数
|
||||
# self.window_ms = 600 # 检测窗口大小 (ms)
|
||||
# self.step_ms = 100 # 滑动步长 (ms)
|
||||
# self.window_samples = int(self.window_ms * self.device_info['sample_rate'] / 1000) # 150个样本点
|
||||
# self.step_samples = int(self.step_ms * self.device_info['sample_rate'] / 1000) # 25个样本点
|
||||
# self.window_samples = int(self.window_ms * self.fs / 1000) # 150个样本点
|
||||
# self.step_samples = int(self.step_ms * self.fs / 1000) # 25个样本点
|
||||
# self.buffer_size = self.window_samples + self.step_samples * 5
|
||||
# self.fp1_buffer = deque(maxlen=self.buffer_size)
|
||||
# self.fp2_buffer = deque(maxlen=self.buffer_size)
|
||||
@@ -154,11 +153,11 @@ class Decoder_main(threading.Thread):
|
||||
# self.double_blink_events = [] # 连续眨眼事件记录
|
||||
# self.last_double_blink_time = 0 # 上次检测到连续眨眼的时间戳
|
||||
# self.blink_events = []
|
||||
# self.blink_b, self.blink_a = signal.butter(4, [self.l_freq / (self.device_info['sample_rate'] / 2), self.h_freq / (self.device_info['sample_rate'] / 2)], btype='band')
|
||||
# self.blink_b, self.blink_a = signal.butter(4, [self.l_freq / (self.fs / 2), self.h_freq / (self.fs / 2)], btype='band')
|
||||
|
||||
def parameter_init(self,bandPass_low,bandPass_high):
|
||||
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in self.interval_epoch] # epoch截取信息
|
||||
self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])] # 训练样本epoch
|
||||
self.interval_epoch = [int(i * self.fs) for i in self.interval_epoch] # epoch截取信息
|
||||
self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * self.fs)] # 训练样本epoch
|
||||
self.trainData = [] #训练数据
|
||||
self.trainLabel = [] #训练标签
|
||||
self.plotData = [] #报告分析数据
|
||||
@@ -166,12 +165,12 @@ class Decoder_main(threading.Thread):
|
||||
self.currentLabel = -1 #刺激界面当前显示的训练标签
|
||||
self.train_started = False #是否开始训练模型
|
||||
self.load_model = False # 调用模型是否完成的标志
|
||||
self.b_notch, self.a_notch = signal.iirnotch(50 / (self.device_info['sample_rate']/2), 30) # 50Hz工频陷波,250是采样率,30是质量因子
|
||||
self.b_design = signal.firwin(65, [bandPass_low / (self.device_info['sample_rate']/2), bandPass_high / (self.device_info['sample_rate']/2)], pass_zero=False) # 设计8-30Hz带通滤波器
|
||||
filePath = os.path.join(get_root_path(), MODEL_FOLDER) + os.sep
|
||||
self.b_notch, self.a_notch = signal.iirnotch(50 / (self.fs/2), 30) # 50Hz工频陷波,250是采样率,30是质量因子
|
||||
self.b_design = signal.firwin(65, [bandPass_low / (self.fs/2), bandPass_high / (self.fs/2)], pass_zero=False) # 设计8-30Hz带通滤波器
|
||||
fileName = 'Model_' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
|
||||
filePath = './online_Models/'
|
||||
for old_pth in glob.glob(os.path.join(filePath, '*.pth')):
|
||||
os.remove(old_pth)
|
||||
fileName = 'Model_' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
|
||||
self.modelPath = ''.join([filePath, fileName, '.pth'])
|
||||
self.mp_data_queue = mp.Queue()
|
||||
self.mp_result_queue = mp.Queue()
|
||||
@@ -188,13 +187,8 @@ class Decoder_main(threading.Thread):
|
||||
|
||||
def run(self):
|
||||
while self.Runing:
|
||||
# 当滤波数据大于5秒时,启动滤波线程
|
||||
if not self.sliding_filter.is_alive() and self.zmqServer.filterBuffer.GetDataLenCount() > self.device_info['sample_rate'] * 5:
|
||||
algo_log("启动滤波线程", level="DEBUG")
|
||||
self.sliding_filter.start()
|
||||
|
||||
if self.zmqServer.decoder_switch or self.zmqServer.changeTarget:
|
||||
algo_log(f"Decoder_class Switch Detected: {self.zmqServer.decoder_class}", level="DEBUG")
|
||||
print(f"Decoder_class Switch Detected: {self.zmqServer.decoder_class}")
|
||||
self.zmqServer.decoder_switch = False
|
||||
self.zmqServer.changeTarget = False
|
||||
self.reset_state() # 切换前先统一清理旧状态
|
||||
@@ -202,97 +196,150 @@ class Decoder_main(threading.Thread):
|
||||
|
||||
# 同步信息
|
||||
if self.zmqServer.state_mode == 'sync':
|
||||
# self.zmqClient.send_to_all('sync', self.zmqClient.state)
|
||||
self.zmqClient.send_to_all('sync', self.zmqClient.state)
|
||||
self.zmqServer.state_mode = 'rest'
|
||||
# 状态异常,报告上位机
|
||||
if self.status_code != self.thread_data_server.status_code:
|
||||
self.status_code = self.thread_data_server.status_code
|
||||
self.zmqClient.send_to_all('status_code', int(self.status_code))
|
||||
print('status code')
|
||||
|
||||
# 返回电量
|
||||
if self.energy != self.thread_data_server.energy:
|
||||
self.energy = self.thread_data_server.energy
|
||||
self.zmqClient.send_to_all('energy', int(self.energy))
|
||||
print('energy')
|
||||
|
||||
if self.zmqServer.open_Impedance == True: # 开启阻抗检测功能,仅运行一次
|
||||
self.thread_data_server.Impedance(True)
|
||||
print('Impedance')
|
||||
self.zmqServer.open_Impedance = -1
|
||||
elif self.zmqServer.open_Impedance == False:
|
||||
self.thread_data_server.Impedance(False)
|
||||
self.zmqServer.open_Impedance = -1
|
||||
|
||||
if self.zmqServer.get_Impedance: # 返回阻抗值
|
||||
# print(self.zmqServer.get_Impedance)
|
||||
# print(self.thread_data_server.GetDataLenCount())
|
||||
if self.thread_data_server.GetDataLenCount() > 250:
|
||||
Impe_data = self.thread_data_server.getData(250)
|
||||
# 计算阻抗
|
||||
imps = self.thread_data_server.getImpedance(Impe_data,self.zmqServer.decoder_class)
|
||||
self.zmqClient.send_to_all('impedance', imps.tolist())
|
||||
else:
|
||||
pass
|
||||
if self.zmqServer.getReport: #返回训练报告内容
|
||||
self.zmqServer.getReport = False
|
||||
allData = np.array(self.plotData)
|
||||
allLabel = np.array(self.plotLabel) + 1
|
||||
nTrials = min(len(allLabel),len(allData))
|
||||
if nTrials < 30:
|
||||
self.zmqClient.send_to_all('miReport',0)
|
||||
else:
|
||||
allData = allData[:nTrials]
|
||||
allLabel = allLabel[:nTrials]
|
||||
ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1',
|
||||
'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4']
|
||||
compare_names = ['C3', 'CZ', 'C4']
|
||||
miReport = plotMain(ch_names=ch_names,compare_names=compare_names,Data=allData,labels=allLabel,MI_label=1,Rest_label=2,
|
||||
fs=self.fs)
|
||||
self.zmqClient.send_to_all('miReport',miReport)
|
||||
|
||||
|
||||
# --- 取数优先:先执行 decoder(消费环形缓冲),再处理 plot/report 等重负载 ---
|
||||
try:
|
||||
if self.zmqServer.open_Impedance:
|
||||
time.sleep(0.005)
|
||||
continue
|
||||
if self.decoder_class == 'ssvep' or self.decoder_class == 'pvs':
|
||||
self.decoder_SSVEP()
|
||||
elif self.decoder_class == 'ssmvep':
|
||||
self.decoder_SSMVEP()
|
||||
elif self.decoder_class == 'mi':
|
||||
self.decoder_MI()
|
||||
elif self.decoder_class == 'concentration':
|
||||
self.decoder_concentration()
|
||||
elif self.decoder_class == 'blink':
|
||||
self.decoder_blink()
|
||||
else:
|
||||
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
|
||||
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||
if self.thread_data_server.GetDataLenCount() < 25:
|
||||
time.sleep(0.005)
|
||||
continue
|
||||
self.zmqServer.paradigmBuffer.getData(25)
|
||||
continue;
|
||||
self.thread_data_server.getData(25)
|
||||
except Exception as e:
|
||||
algo_log(f"Decoder Loop Error: {e}")
|
||||
print(f"Decoder Loop Error: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
time.sleep(0.1) # Prevent CPU spin if error is persistent
|
||||
|
||||
def decoder_SSVEP(self):
|
||||
if self.zmqServer.StartDecode:
|
||||
self.zmqServer.StartDecode = False
|
||||
self.decodingSteps = 1
|
||||
self.zmqServer.paradigmBuffer.resetAllPara()
|
||||
algo_log('启动SSVEP预测', level="DEBUG")
|
||||
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 50:
|
||||
self.thread_data_server.ResetAll()
|
||||
print('启动预测')
|
||||
if self.thread_data_server.GetDataLenCount() < 50:
|
||||
time.sleep(0.005)
|
||||
return
|
||||
if self.zmqServer.open_Impedance: # 阻抗检测状态不解码
|
||||
if self.zmqServer.get_Impedance != False: # 阻抗检测状态不解码
|
||||
return
|
||||
data = self.zmqServer.paradigmBuffer.getDataViaSSVEP(50)
|
||||
# algo_log(f"SSVEP取出的:{data.shape}, data = {data[:, :10]}", level="DEBUG")
|
||||
data = self.thread_data_server.getDataViaSSVEP(50)
|
||||
data = data[:self.n_chan, :]
|
||||
if self.decodingSteps == 1 and hasattr(self,'dw'): # 开始预热
|
||||
self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时
|
||||
self.dw.warmFilter(data) # 预热
|
||||
self.decodingSteps = 2
|
||||
algo_log('SSVEP预热数据完成。开始预测', level="DEBUG")
|
||||
print('预热数据完成。开始预测')
|
||||
return
|
||||
if self.decodingSteps == 2 and hasattr(self,'dw'): # 解码中
|
||||
choosenNum = self.dw.fbccaDWMW(data, self.referenceData, self.DW_cost_tv, self.calculateCount)
|
||||
self.calculateCount += 1
|
||||
if choosenNum != -1 and self.is_valid_signal(data):
|
||||
self.decodingSteps = 3
|
||||
algo_log('SSVEP预测结果:' + str(choosenNum) + ',计算次数:' + str(self.calculateCount), level="DEBUG")
|
||||
print('预测结果:' + str(choosenNum) + ',计算次数:' + str(self.calculateCount))
|
||||
self.calculateCount = 0
|
||||
if self.decodingSteps == 3: # 发送解码后的信息
|
||||
self.zmqServer.broadcast_message('result', int(choosenNum))
|
||||
self.zmqClient.send_to_all('result', int(choosenNum))
|
||||
self.decodingSteps = 0
|
||||
algo_log('SSVEP发送给界面完成。', level="DEBUG")
|
||||
print('发送给界面完成。')
|
||||
|
||||
def decoder_SSMVEP(self):
|
||||
'''模型训练'''
|
||||
if self.load_model == False and all(
|
||||
self.trainLabel.count(i) >= self.single_train for i in [1, 2]): # 模型尚未训练完成
|
||||
self.trainLabel.count(i) >= self.single_train for i in range(len(self.list_freqs))): # 模型尚未训练完成
|
||||
self.trainData = np.array(self.trainData)
|
||||
self.trainLabel = np.array(self.trainLabel)
|
||||
algo_log(f"开始SSMVEP模型训练,数据形状:{np.shape(self.trainData)},标签形状:{self.trainLabel.shape}", level="DEBUG")
|
||||
if save_train_data == 1:
|
||||
now_str = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
save_path = f"{now_str}.npz"
|
||||
np.savez(save_path, array1=self.trainData, array2=self.trainLabel)
|
||||
print(np.shape(self.trainData), (self.trainLabel))
|
||||
# 保存多个数组到文件
|
||||
# np.savez('20250520_yy.npz', array1=self.trainData, array2=self.trainLabel)
|
||||
# self.decoder = self.fbtdca.fit(self.trainData, self.trainLabel, Yf=self.Yf)
|
||||
self.decoder = self.tdca.fit(self.trainData, self.trainLabel, Yf=self.Yf)
|
||||
now = datetime.now()
|
||||
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
||||
algo_log(f"SSMVEP模型训练完成,时间:{formatted_time}", level="DEBUG")
|
||||
print('模型训练完成', formatted_time)
|
||||
self.load_model = True
|
||||
self.zmqServer.broadcast_message('paradigm', 1)
|
||||
self.zmqClient.send_to_all('paradigm', 1)
|
||||
|
||||
'''训练阶段采集数据'''
|
||||
if self.zmqServer.state_mode == 'train': # 训练状态
|
||||
if self.zmqServer.epoch_finished and self.zmqServer.paradigmBuffer.GetDataLenCount() >= \
|
||||
self.train_epoch[1] + self.zmqServer.event_inner_idx:
|
||||
if self.zmqServer.StartTrain:
|
||||
self.currentLabel = self.zmqServer.currentLabel
|
||||
trainTrial = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 取出所有数据
|
||||
algo_log(f"取出的:{trainTrial.shape},event:{trainTrial[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
|
||||
self.zmqServer.StartTrain = False
|
||||
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \
|
||||
self.train_epoch[1] \
|
||||
+ self.thread_data_server.event_inner_idx:
|
||||
time.sleep(0.0001)
|
||||
return
|
||||
print('训练队列数据:', self.thread_data_server.GetDataLenCount())
|
||||
trainTrial = self.thread_data_server.get_SSMVEPData() # 取出所有数据
|
||||
print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, self.thread_data_server.event_inner_idx])
|
||||
trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理
|
||||
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.train_epoch[
|
||||
0]:self.zmqServer.event_inner_idx + self.train_epoch[1]]
|
||||
trainTrial = trainTrial[:, self.thread_data_server.event_inner_idx + self.train_epoch[
|
||||
0]:self.thread_data_server.event_inner_idx + self.train_epoch[1]]
|
||||
print('trial: ', self.thread_data_server.event_inner_idx, self.train_epoch[0], self.train_epoch[1])
|
||||
if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance(
|
||||
self.trainLabel, list) \
|
||||
and self.trainLabel.count(self.currentLabel) < self.single_train:
|
||||
algo_log(f"SSMVEP训练集:{np.shape(self.trainData)}", level="DEBUG")
|
||||
self.trainData.append(trainTrial)
|
||||
self.trainLabel.append(self.currentLabel)
|
||||
else:
|
||||
time.sleep(0.0001)
|
||||
return
|
||||
|
||||
elif self.zmqServer.state_mode == 'predict': # 测试状态
|
||||
if self.load_model == False: # 模型尚未训练完成
|
||||
@@ -303,47 +350,45 @@ class Decoder_main(threading.Thread):
|
||||
self.zmqServer.StartDecode = False
|
||||
now = datetime.now()
|
||||
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
||||
algo_log(f"SSMVEP模型启动预测 {formatted_time}", level="DEBUG")
|
||||
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
||||
print('启动预测 ', formatted_time)
|
||||
|
||||
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \
|
||||
self.interval_epoch[1] \
|
||||
+ self.zmqServer.event_inner_idx:
|
||||
# algo_log(f"SSMVEP模型启动预测 {self.zmqServer.epoch_finished}", level="DEBUG")
|
||||
+ self.thread_data_server.event_inner_idx:
|
||||
time.sleep(0.0001)
|
||||
return
|
||||
data = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 读取全部数据
|
||||
algo_log(f"取出的:{data.shape}, event: {data[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
|
||||
data = self.thread_data_server.get_SSMVEPData() # 读取全部数据
|
||||
print('取出的: ', data.shape, 'event: ', data[-2, self.thread_data_server.event_inner_idx])
|
||||
data = self.preprocess(data[:self.n_chan, :]) # 预处理
|
||||
data = data[:,
|
||||
self.zmqServer.event_inner_idx + self.interval_epoch[
|
||||
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]
|
||||
self.thread_data_server.event_inner_idx + self.interval_epoch[
|
||||
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]]
|
||||
pad_eeg_test = np.zeros(
|
||||
(data.shape[0], int((self.sample_length + 0.1) * self.device_info['sample_rate'])))
|
||||
pad_eeg_test[:, :int(self.sample_length * self.device_info['sample_rate'])] = data
|
||||
(data.shape[0], int((self.sample_length + 0.1) * self.fs)))
|
||||
pad_eeg_test[:, :int(self.sample_length * self.fs)] = data
|
||||
choosenNum, features_2 = self.decoder.predict(pad_eeg_test)
|
||||
if isinstance(choosenNum, np.ndarray):
|
||||
choosenNum = choosenNum[0]
|
||||
algo_log(f"结果:{choosenNum}, rho: {sorted(features_2[0])[-1] - sorted(features_2[0])[-2]}", level="DEBUG")
|
||||
self.zmqServer.broadcast_message('result', int(choosenNum))
|
||||
algo_log("SSMVEP发送给界面完成。", level="DEBUG")
|
||||
print('结果:', choosenNum, 'rho: ', sorted(features_2[0]),
|
||||
sorted(features_2[0])[-1] - sorted(features_2[0])[-2])
|
||||
self.zmqClient.send_to_all('result', int(choosenNum))
|
||||
print('发送给界面完成。')
|
||||
else: # 休息状态
|
||||
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
|
||||
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||
if self.thread_data_server.GetDataLenCount() < 25:
|
||||
time.sleep(0.005)
|
||||
return
|
||||
self.zmqServer.paradigmBuffer.getData(25)
|
||||
self.thread_data_server.getData(25)
|
||||
|
||||
def decoder_MI(self):
|
||||
'''模型训练'''
|
||||
if self.train_started == False and all(
|
||||
self.trainLabel.count(i) >= self.single_train for i in [1, 2]): # 模型尚未训练
|
||||
self.zmqServer.broadcast_message('paradigm', 2) # 模型训练前,训练集采集完毕,通知上位机
|
||||
self.trainLabel.count(i) >= self.single_train for i in range(self.num_target)): # 模型尚未训练
|
||||
self.zmqClient.send_to_all('paradigm', 2) # 模型训练前,训练集采集完毕,通知上位机
|
||||
self.train_started = True
|
||||
self.trainData = np.array(self.trainData)
|
||||
self.trainLabel = np.array(self.trainLabel)
|
||||
algo_log(f"MI开始训练,训练集:{np.shape(self.trainData)},标签shape:{np.shape(self.trainLabel)}", level="DEBUG")
|
||||
if save_train_data == 1:
|
||||
now_str = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
save_path = f"{now_str}.npz"
|
||||
np.savez(save_path, array1=self.trainData, array2=self.trainLabel)
|
||||
self.trainLabel = np.array(self.trainLabel) + 1
|
||||
# print('训练集:',np.shape(self.trainData), (self.trainLabel))
|
||||
p = mp.Process(target=onlineTrain, args=(self.mp_data_queue, self.mp_result_queue)) # 开启子进程,训练模型
|
||||
p.start()
|
||||
self.mp_data_queue.put({'data': self.trainData, 'label': self.trainLabel, 'modelPath': self.modelPath,
|
||||
@@ -354,7 +399,7 @@ class Decoder_main(threading.Thread):
|
||||
try:
|
||||
result = self.mp_result_queue.get_nowait()
|
||||
if result['status'] == 'success':
|
||||
algo_log("MI模型训练完成,加载新模型", level="DEBUG")
|
||||
print("模型训练完成,加载新模型")
|
||||
# 调用模型
|
||||
self.model = torch.load(self.modelPath, weights_only=False)
|
||||
self.model.eval()
|
||||
@@ -365,61 +410,63 @@ class Decoder_main(threading.Thread):
|
||||
with torch.no_grad():
|
||||
_ = self.model(warmup_data)
|
||||
self.load_model = True
|
||||
self.zmqServer.broadcast_message('paradigm', 1) # 模型调用完毕,通知上位机
|
||||
self.zmqClient.send_to_all('paradigm', 1) # 模型调用完毕,通知上位机
|
||||
else:
|
||||
algo_log("MI训练失败: " + result['msg'], level="DEBUG")
|
||||
print("训练失败:", result['msg'])
|
||||
except Empty:
|
||||
pass # 还没完成
|
||||
except Exception as e:
|
||||
algo_log("MI模型训练失败: " + str(e), level="DEBUG")
|
||||
print('模型调用失败: ', e)
|
||||
|
||||
'''训练阶段采集数据'''
|
||||
if self.zmqServer.state_mode == 'train' and self.train_started == False: # 训练状态
|
||||
if self.zmqServer.epoch_finished and self.zmqServer.paradigmBuffer.GetDataLenCount() >= \
|
||||
self.zmqServer.train_epoch[1] + self.zmqServer.event_inner_idx:
|
||||
self.currentLabel = self.zmqServer.currentLabel # 同步当前标签
|
||||
algo_log(f"训练队列数据:{self.zmqServer.paradigmBuffer.GetDataLenCount()}", level="DEBUG")
|
||||
originalTrial = self.zmqServer.paradigmBuffer.get_MIData() # 取出MI导联数据
|
||||
algo_log(f"取出的:{originalTrial.shape},event: {originalTrial[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
|
||||
if self.zmqServer.StartTrain:
|
||||
self.currentLabel = self.zmqServer.currentLabel
|
||||
self.zmqServer.StartTrain = False
|
||||
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \
|
||||
self.interval_epoch[1] \
|
||||
+ self.thread_data_server.event_inner_idx:
|
||||
time.sleep(0.0001)
|
||||
return
|
||||
print('训练队列数据:', self.thread_data_server.GetDataLenCount())
|
||||
originalTrial = self.thread_data_server.get_MIData() # 取出MI导联数据
|
||||
print('取出的: ', originalTrial.shape, 'event: ', originalTrial[-2, self.thread_data_server.event_inner_idx])
|
||||
trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理
|
||||
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.interval_epoch[
|
||||
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]
|
||||
# algo_log(f"trial: {self.zmqServer.event_inner_idx},{self.interval_epoch[0]},{self.interval_epoch[1]}", level="DEBUG")
|
||||
trainTrial = trainTrial[:, self.thread_data_server.event_inner_idx + self.interval_epoch[
|
||||
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]]
|
||||
print('trial: ', self.thread_data_server.event_inner_idx, self.interval_epoch[0], self.interval_epoch[1])
|
||||
if trainTrial.shape[1] == (self.interval_epoch[1] - self.interval_epoch[0]) and isinstance(self.trainLabel,
|
||||
list) \
|
||||
and self.trainLabel.count(self.currentLabel) < self.single_train:
|
||||
self.trainData.append(trainTrial)
|
||||
self.trainLabel.append(self.currentLabel)
|
||||
algo_log(f"训练集:{np.shape(self.trainData)}", level="DEBUG")
|
||||
self.plotData.append(originalTrial[:self.n_chan, self.zmqServer.event_inner_idx + self.interval_epoch[
|
||||
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]])
|
||||
print('训练集:', np.shape(self.trainData))
|
||||
self.plotData.append(originalTrial[:self.n_chan, self.thread_data_server.event_inner_idx + self.interval_epoch[
|
||||
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]])
|
||||
self.plotLabel.append(self.currentLabel)
|
||||
else:
|
||||
time.sleep(0.0001)
|
||||
return
|
||||
|
||||
elif self.zmqServer.state_mode == 'predict' and self.load_model == True: # 测试状态
|
||||
if self.zmqServer.StartDecode:
|
||||
self.zmqServer.StartDecode = False
|
||||
now = datetime.now()
|
||||
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
||||
algo_log(f"MI启动预测 {formatted_time}", level="DEBUG")
|
||||
print('启动预测 ', formatted_time)
|
||||
|
||||
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
||||
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \
|
||||
self.interval_epoch[1] \
|
||||
+ self.zmqServer.event_inner_idx:
|
||||
+ self.thread_data_server.event_inner_idx:
|
||||
time.sleep(0.0001)
|
||||
return
|
||||
originalData = self.zmqServer.paradigmBuffer.get_MIData() # 读取全部数据
|
||||
algo_log(f"取出的:{originalData.shape},event: {originalData[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
|
||||
originalData = self.thread_data_server.get_MIData() # 读取全部数据
|
||||
print('取出的: ', originalData.shape, 'event: ', originalData[-2, self.thread_data_server.event_inner_idx])
|
||||
start = time.time()
|
||||
data = self.preprocess(originalData[:self.n_chan, :]) # 预处理
|
||||
data = data[:,
|
||||
self.zmqServer.event_inner_idx + self.interval_epoch[
|
||||
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]
|
||||
self.thread_data_server.event_inner_idx + self.interval_epoch[
|
||||
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]]
|
||||
self.plotData.append(
|
||||
originalData[:self.n_chan, self.zmqServer.event_inner_idx + self.interval_epoch[
|
||||
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]])
|
||||
originalData[:self.n_chan, self.thread_data_server.event_inner_idx + self.interval_epoch[
|
||||
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]])
|
||||
|
||||
test_data = data[np.newaxis, np.newaxis, :, :]
|
||||
test_data = torch.from_numpy(test_data)
|
||||
@@ -428,40 +475,134 @@ class Decoder_main(threading.Thread):
|
||||
Cls = self.model(test_data)
|
||||
y_pred = torch.max(Cls, 1)[1]
|
||||
self.plotLabel.append(int(y_pred.item()))
|
||||
algo_log(f"MI运动意图识别: {y_pred}")
|
||||
self.zmqServer.broadcast_message('result', int(y_pred.item()))
|
||||
print('运动意图识别: ', y_pred)
|
||||
self.zmqClient.send_to_all('result', int(y_pred.item()))
|
||||
end = time.time()
|
||||
algo_log(f'MI发送给界面完成,耗时{end - start:.3f}s。')
|
||||
print(f'发送给界面完成,耗时{end - start:.3f}s。')
|
||||
else: # 休息状态
|
||||
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
|
||||
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||
if self.thread_data_server.GetDataLenCount() < 25:
|
||||
time.sleep(0.005)
|
||||
return
|
||||
self.zmqServer.paradigmBuffer.getData(25)
|
||||
self.thread_data_server.getData(25)
|
||||
|
||||
# def decoder_concentration(self):
|
||||
# if self.zmqServer.state_mode == 'predict':
|
||||
# if self.zmqServer.StartDecode:
|
||||
# self.zmqServer.StartDecode = False
|
||||
# self.thread_data_server.ResetAll()
|
||||
# now = datetime.now()
|
||||
# formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
||||
# print('启动专注力预测 ', formatted_time)
|
||||
# if self.thread_data_server.GetDataLenCount() < int(self.win_step * self.device_info['sample_rate']): # 每win_step得出一次结果
|
||||
# time.sleep(0.005)
|
||||
# return
|
||||
# if self.zmqServer.get_Impedance != False: # 阻抗检测状态不解码
|
||||
# return
|
||||
# data = self.thread_data_server.get_concentrateData(int(self.win_step * self.device_info['sample_rate'])) # 修改每次读取的数据
|
||||
# result = self.calculate.queueOpt(data)
|
||||
# if result is not None:
|
||||
# self.zmqClient.send_to_all('result', int(result))
|
||||
# else: # 休息状态
|
||||
# if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||
# if self.thread_data_server.GetDataLenCount() < 25:
|
||||
# time.sleep(0.005)
|
||||
# return
|
||||
# self.thread_data_server.getData(25)
|
||||
def decoder_concentration(self):
|
||||
if self.zmqServer.state_mode == 'predict':
|
||||
if self.zmqServer.StartDecode:
|
||||
self.zmqServer.StartDecode = False
|
||||
self.thread_data_server.ResetAll()
|
||||
now = datetime.now()
|
||||
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
||||
print('启动专注力预测 ', formatted_time)
|
||||
if self.thread_data_server.GetDataLenCount() < int(self.win_step * self.fs): # 每win_step得出一次结果
|
||||
time.sleep(0.005)
|
||||
return
|
||||
if self.zmqServer.get_Impedance != False: # 阻抗检测状态不解码
|
||||
return
|
||||
data = self.thread_data_server.get_concentrateData(int(self.win_step * self.fs)) # 修改每次读取的数据
|
||||
result = self.calculate.queueOpt(data)
|
||||
if result is not None:
|
||||
self.zmqClient.send_to_all('result', int(result))
|
||||
else: # 休息状态
|
||||
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||
if self.thread_data_server.GetDataLenCount() < 25:
|
||||
time.sleep(0.005)
|
||||
return
|
||||
self.thread_data_server.getData(25)
|
||||
|
||||
#### Blink detection #####
|
||||
def check_double_blink(self, current_time):
|
||||
"""
|
||||
检查是否检测到连续两次眨眼
|
||||
@param current_time: 当前眨眼时间戳
|
||||
@return: True表示检测到连续两次眨眼
|
||||
"""
|
||||
if len(self.blink_timestamps) < 2:
|
||||
return False
|
||||
|
||||
# 检查是否在去抖期内
|
||||
if self.last_double_blink_time > 0:
|
||||
time_since_last_double_blink = current_time - self.last_double_blink_time
|
||||
if time_since_last_double_blink < self.double_blink_jitter:
|
||||
return False # 在去抖期内,忽略连续眨眼检测
|
||||
last_time = self.blink_timestamps[-1] # 当前眨眼
|
||||
prev_time = self.blink_timestamps[-2] # 上次眨眼
|
||||
|
||||
interval = last_time - prev_time
|
||||
if interval <= self.double_blink_interval:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def process_blink_detection(self):
|
||||
"""
|
||||
在缓冲区数据上执行,单次眨眼检测
|
||||
"""
|
||||
if len(self.fp1_buffer) < self.window_samples:
|
||||
return
|
||||
|
||||
fp1_data = np.array(list(self.fp1_buffer)[-self.window_samples:])
|
||||
fp2_data = np.array(list(self.fp2_buffer)[-self.window_samples:])
|
||||
# 计算FP1和FP2的平均
|
||||
fp12_mean = (fp1_data + fp2_data) / 2.0
|
||||
# 带通滤波
|
||||
try:
|
||||
fp12_filtered = signal.filtfilt(self.blink_b, self.blink_a, fp12_mean)
|
||||
except Exception as e:
|
||||
print(f"Filter error: {e}")
|
||||
return
|
||||
F = np.diff(fp12_filtered)
|
||||
if len(F) < 3:
|
||||
return
|
||||
b, d, e = blink_detection(F, self.fs, self.Dmin, self.Dmax, self.EMin, self.EMax)
|
||||
|
||||
if b == 1:
|
||||
samples_since_last = self.total_samples - self.last_blink_time
|
||||
time_since_last_ms = (samples_since_last / self.fs) * 1000
|
||||
if time_since_last_ms >= self.jitterwin: # self.jitterwin 单次眨眼去抖 using time_since_last_ms
|
||||
self.blink_count += 1
|
||||
self.last_blink_time = self.total_samples
|
||||
current_time = time.time()
|
||||
self.blink_timestamps.append(current_time)
|
||||
blink_event = {
|
||||
'count': self.blink_count,
|
||||
'time': current_time,
|
||||
'sample_index': self.total_samples,
|
||||
'duration_ms': d,
|
||||
'energy': e
|
||||
}
|
||||
self.blink_events.append(blink_event)
|
||||
self.zmqClient.send_to_all('result', 1) # 检测到眨眼信号,通知上位机
|
||||
if self.check_double_blink(current_time):
|
||||
self.double_blink_count += 1
|
||||
interval = self.blink_timestamps[-1] - self.blink_timestamps[-2]
|
||||
double_blink_event = {
|
||||
'double_blink_count': self.double_blink_count,
|
||||
'blink1_time': self.blink_timestamps[-2],
|
||||
'blink2_time': self.blink_timestamps[-1],
|
||||
'interval': interval
|
||||
}
|
||||
self.double_blink_events.append(double_blink_event)
|
||||
self.last_double_blink_time = current_time
|
||||
self.zmqClient.send_to_all('result', 2) # 发送双次眨眼事件
|
||||
|
||||
def decoder_blink(self):
|
||||
if self.thread_data_server.GetDataLenCount() < 50:
|
||||
time.sleep(0.005)
|
||||
return
|
||||
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||
data = self.thread_data_server.get_blinkData(50)
|
||||
fp1_data = data[0, :] # ch1 (相当于FP1)
|
||||
fp2_data = data[1, :] # ch2 (相当于FP2)
|
||||
for i in range(len(fp1_data)):
|
||||
self.fp1_buffer.append(fp1_data[i])
|
||||
self.fp2_buffer.append(fp2_data[i])
|
||||
self.total_samples += 1
|
||||
self.sample_counter += 1
|
||||
|
||||
if self.sample_counter >= self.step_samples:
|
||||
self.process_blink_detection()
|
||||
self.sample_counter = 0
|
||||
|
||||
def stop(self):
|
||||
'''
|
||||
@@ -469,13 +610,12 @@ class Decoder_main(threading.Thread):
|
||||
@return:
|
||||
'''
|
||||
self.zmqServer.stop()
|
||||
self.sliding_filter.stop()
|
||||
self.Runing=False
|
||||
|
||||
def reset_state(self):
|
||||
"""清空解码器状态和缓存数据"""
|
||||
# 重置设备层缓存
|
||||
self.zmqServer.reset_state()
|
||||
self.thread_data_server.reset_state()
|
||||
|
||||
# 重置解码状态
|
||||
self.decodingSteps = 0
|
||||
|
||||
@@ -34,7 +34,7 @@ cudnn.benchmark = True
|
||||
cudnn.deterministic = True
|
||||
from sklearn.model_selection import train_test_split
|
||||
# writer = SummaryWriter('./TensorBoardX/')
|
||||
from logs.log import algo_log
|
||||
|
||||
|
||||
# Convolution module
|
||||
# use conv to capture local features, instead of postion embedding.
|
||||
@@ -82,7 +82,7 @@ class MultiHeadAttention(nn.Module):
|
||||
values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
|
||||
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
|
||||
if mask is not None:
|
||||
fill_value = torch.finfo(torch.float64).min
|
||||
fill_value = torch.finfo(torch.float32).min
|
||||
energy.mask_fill(~mask, fill_value)
|
||||
|
||||
scaling = self.emb_size ** (1 / 2)
|
||||
@@ -318,7 +318,11 @@ class ExP():
|
||||
train_pred = torch.max(outputs, 1)[1]
|
||||
train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))
|
||||
|
||||
algo_log(f"Epoch = {e}, Train loss = {loss.detach().cpu().numpy():.6f}, Test loss = {loss_test.detach().cpu().numpy():.6f}, Train accuracy = {train_acc:.6f}, Test accuracy = {acc:.6f}", level="debug")
|
||||
print('Epoch:', e,
|
||||
' Train loss: %.6f' % loss.detach().cpu().numpy(),
|
||||
' Test loss: %.6f' % loss_test.detach().cpu().numpy(),
|
||||
' Train accuracy %.6f' % train_acc,
|
||||
' Test accuracy is %.6f' % acc)
|
||||
|
||||
self.log_write.write(str(e) + " " + str(acc) + "\n")
|
||||
num = num + 1
|
||||
@@ -331,8 +335,8 @@ class ExP():
|
||||
|
||||
torch.save(self.model, model_path)
|
||||
averAcc = averAcc / num
|
||||
algo_log(f"The average accuracy is: {averAcc}", level="debug")
|
||||
algo_log(f"The best accuracy is: {bestAcc}", level="debug")
|
||||
print('The average accuracy is:', averAcc)
|
||||
print('The best accuracy is:', bestAcc)
|
||||
self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n")
|
||||
self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n")
|
||||
|
||||
@@ -342,10 +346,10 @@ class ExP():
|
||||
|
||||
def onlineTrain(data_queue,result_queue):
|
||||
import torch
|
||||
algo_log(f"[DEBUG] torch.__version__ = {torch.__version__}", level="debug")
|
||||
algo_log(f"[DEBUG] torch.cuda.is_available() = {torch.cuda.is_available()}", level="debug")
|
||||
print(f"[DEBUG] torch.__version__ = {torch.__version__}")
|
||||
print(f"[DEBUG] torch.cuda.is_available() = {torch.cuda.is_available()}")
|
||||
if torch.cuda.is_available():
|
||||
algo_log(f"[DEBUG] GPU = {torch.cuda.get_device_name(0)}", level="debug")
|
||||
print(f"[DEBUG] GPU = {torch.cuda.get_device_name(0)}")
|
||||
try:
|
||||
starttime = datetime.datetime.now()
|
||||
|
||||
@@ -362,13 +366,12 @@ def onlineTrain(data_queue,result_queue):
|
||||
data = data_queue.get(timeout=30)
|
||||
all_data, all_label,model_path,n_chan = data['data'], data['label'],data['modelPath'],data['n_chan']
|
||||
exp = ExP(n_chan)
|
||||
algo_log(f"训练参数: {np.shape(all_data)}, {np.shape(all_label)}, {model_path}", level="debug")
|
||||
print('训练参数: ',np.shape(all_data),np.shape(all_label),model_path)
|
||||
bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,model_path)
|
||||
algo_log(f"THE BEST ACCURACY IS {str(bestAcc)}", level="debug")
|
||||
print('THE BEST ACCURACY IS ' + str(bestAcc))
|
||||
|
||||
endtime = datetime.datetime.now()
|
||||
algo_log(f"train duration: {endtime - starttime}", level="debug")
|
||||
|
||||
print('train duration: ',str(endtime - starttime))
|
||||
|
||||
# 将模型或参数传回
|
||||
result_queue.put({
|
||||
@@ -384,7 +387,7 @@ def offlineTrain(all_data,all_label,modelPath):
|
||||
|
||||
# seed_n = np.random.randint(2025)
|
||||
seed_n = 1877
|
||||
algo_log(f"seed is {seed_n}", level="debug")
|
||||
print('seed is ' + str(seed_n))
|
||||
random.seed(seed_n)
|
||||
np.random.seed(seed_n)
|
||||
torch.manual_seed(seed_n)
|
||||
@@ -394,12 +397,13 @@ def offlineTrain(all_data,all_label,modelPath):
|
||||
exp = ExP()
|
||||
|
||||
bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,modelPath)
|
||||
algo_log('THE BEST ACCURACY IS ' + str(bestAcc), level="debug")
|
||||
print('THE BEST ACCURACY IS ' + str(bestAcc))
|
||||
|
||||
endtime = datetime.datetime.now()
|
||||
algo_log(f"train duration: {endtime - starttime}", level="debug")
|
||||
print('train duration: ',str(endtime - starttime))
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
algo_log(f"[DEBUG] time.asctime(time.localtime(time.time())) = {time.asctime(time.localtime(time.time()))}", level="debug")
|
||||
print(time.asctime(time.localtime(time.time())))
|
||||
print(time.asctime(time.localtime(time.time())))
|
||||
|
||||
@@ -22,7 +22,6 @@ from einops import rearrange
|
||||
from einops.layers.torch import Rearrange, Reduce
|
||||
from torch.backends import cudnn
|
||||
from sklearn.model_selection import train_test_split
|
||||
from logs.log import algo_log
|
||||
# writer = SummaryWriter('./TensorBoardX/')
|
||||
|
||||
|
||||
@@ -72,7 +71,7 @@ class MultiHeadAttention(nn.Module):
|
||||
values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
|
||||
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
|
||||
if mask is not None:
|
||||
fill_value = torch.finfo(torch.float64).min
|
||||
fill_value = torch.finfo(torch.float32).min
|
||||
energy.mask_fill(~mask, fill_value)
|
||||
|
||||
scaling = self.emb_size ** (1 / 2)
|
||||
@@ -191,7 +190,7 @@ class ExP():
|
||||
# 自动选择设备:有 GPU 用 GPU,否则用 CPU
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
# self.device = torch.device("cpu")
|
||||
algo_log(f"Using device: {self.device}", level="debug")
|
||||
print(f"Using device: {self.device}")
|
||||
|
||||
# 定义张量类型(不再强制使用 cuda)
|
||||
self.Tensor = torch.FloatTensor
|
||||
|
||||
28
README.md
28
README.md
@@ -13,29 +13,5 @@ Debug_64ch_Decoder_Optimize is an updated version that fixes several issues and
|
||||
6. decoder class切换问题
|
||||
7. decoder_class切换时,数据重置、各类参数重置
|
||||
|
||||
# realease log
|
||||
- 2026年6月11日11:29:17 打包第一版,包名runDecoder.dist_v0.0.0_beta_20260611.7z
|
||||
- 2026年6月11日12:00:00 打包第二版,包名runDecoder.dist_v0.0.0_beta_20260611.7z
|
||||
- 修复上位机先发decoder_class, 后发open_impedence 带来decoder_main thread 阻塞问题
|
||||
|
||||
- 2026年6月12日15:05:47 runDecoder.dist_v0.0.2_beta_20260612
|
||||
- 优化filter读数精度
|
||||
|
||||
# 常用命令
|
||||
source activate 3in1Py310
|
||||
python runDecoder.py
|
||||
python datamock.py
|
||||
python ZeroMQClient_mock.py
|
||||
python filter_test.py
|
||||
python upperHost_stimmock/MI_headless.py
|
||||
|
||||
# 打包命令
|
||||
./nuitka_3in1_package.sh
|
||||
|
||||
# TODO
|
||||
1. mvep是否要把list freq 开放到config
|
||||
2. 滤波器参数 放到config文件
|
||||
|
||||
# debug log
|
||||
## MI
|
||||
Epoch采集完成|收到命令: {'method': 'train'|取出的
|
||||
# update
|
||||
2026年6月5日13:55:34
|
||||
@@ -12,17 +12,16 @@ from scipy.io import loadmat
|
||||
from scipy.linalg import qr
|
||||
from scipy.signal import filtfilt, lfilter
|
||||
# from numpy.linalg import _umath_linalg
|
||||
from logs.log import algo_log
|
||||
|
||||
|
||||
class FbccaDw:
|
||||
def __init__(self, fs, num_target, num_chans, num_filter, num_harms, stimTime, parameter, width, winNum,method):
|
||||
algo_log('******************************************', level="debug")
|
||||
algo_log('parameter list',level="debug")
|
||||
algo_log(f"target: {num_target}", level="debug")
|
||||
algo_log(f"number of filter bank: {num_filter}", level="debug")
|
||||
algo_log(f"parameter: {parameter}", level="debug")
|
||||
algo_log(f"width: {width}", level="debug")
|
||||
print('******************************************')
|
||||
print('parameter list')
|
||||
print('target:', num_target)
|
||||
print('number of filter bank:', num_filter)
|
||||
print('parameter:', parameter)
|
||||
print('width:', width)
|
||||
self.phase = 0
|
||||
self.bandWidth = width
|
||||
self.winNum = winNum
|
||||
@@ -238,7 +237,7 @@ class FbccaDw:
|
||||
dataFiltered, self.notchZh[0] = lfilter(self.north_b, self.north_a, data, zi=self.notchZh[0])
|
||||
return np.asmatrix(dataFiltered)
|
||||
except Exception:
|
||||
algo_log(f"Exception: {Exception}", level="debug")
|
||||
print(Exception)
|
||||
|
||||
'''
|
||||
getDataQ
|
||||
|
||||
@@ -1,73 +0,0 @@
|
||||
import numpy as np
|
||||
from scipy.signal import welch
|
||||
from collections import deque
|
||||
|
||||
|
||||
|
||||
class Beta_Calculate():
|
||||
def __init__(self, Threshold_value_low, Threshold_value_high, fs=250, win_len=5, config=None):
|
||||
self.Threshold_value_low = Threshold_value_low
|
||||
self.Threshold_value_high = Threshold_value_high
|
||||
self.fs = fs
|
||||
self.beta_result = []
|
||||
self.eegQueue = deque(maxlen=win_len)
|
||||
|
||||
def calculate_all(self, data, fs, nperseg=1000):
|
||||
mean_x = np.mean(data, axis=-1, keepdims=True)
|
||||
data = data - mean_x
|
||||
freqs, psd = self.compute_psd_multichannel(data, fs, nperseg)
|
||||
beta_psd = np.sum(self.band_psd(freqs, psd, (13, 30)))
|
||||
alpha_psd = np.sum(self.band_psd(freqs, psd, (8, 13)))
|
||||
theta_psd = np.sum(self.band_psd(freqs, psd, (4, 8)))
|
||||
|
||||
# print(f"[功率] β={beta_psd:.2f} | α={alpha_psd:.2f} | θ={theta_psd:.2f}")
|
||||
|
||||
return beta_psd, alpha_psd, theta_psd
|
||||
|
||||
def compute_psd_multichannel(self, data, fs=250, nperseg=1000):
|
||||
n_samples = data.shape[-1]
|
||||
if n_samples < nperseg:
|
||||
nperseg = n_samples
|
||||
|
||||
noverlap = 500
|
||||
if noverlap >= nperseg:
|
||||
noverlap = int(nperseg / 2)
|
||||
|
||||
if nperseg == 0:
|
||||
return np.array([]), np.zeros((data.shape[0], 0))
|
||||
|
||||
freqs, psd = welch(data, fs=fs, nperseg=nperseg, noverlap=noverlap, axis=-1)
|
||||
return freqs, psd
|
||||
|
||||
def band_psd(self, freqs, psd, band):
|
||||
idx = np.logical_and(freqs >= band[0], freqs <= band[1])
|
||||
return np.sum(psd[:, idx], axis=-1)
|
||||
|
||||
|
||||
def reset_queue(self):
|
||||
self.eegQueue.clear()
|
||||
|
||||
|
||||
def queueOpt(self, data):
|
||||
if data is None or data.size == 0:
|
||||
return None
|
||||
if len(self.eegQueue) < self.eegQueue.maxlen:
|
||||
self.eegQueue.append(data)
|
||||
else:
|
||||
self.eegQueue.append(data)
|
||||
|
||||
if len(self.eegQueue) == self.eegQueue.maxlen:
|
||||
eegData = np.hstack([self.eegQueue[i] for i in range(len(self.eegQueue))])
|
||||
if eegData.size == 0:
|
||||
return None
|
||||
eegData -= np.mean(eegData, axis=-1, keepdims=True)
|
||||
|
||||
beta_psd, alpha_psd, theta_psd = self.calculate_all(eegData, fs=self.fs, nperseg=1000)
|
||||
|
||||
return (beta_psd)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,167 +0,0 @@
|
||||
import zmq
|
||||
import time
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
|
||||
def receive_messages(socket, stop_event):
|
||||
"""
|
||||
后台线程函数,用于持续接收服务器消息
|
||||
|
||||
Args:
|
||||
socket (zmq.Socket): ZeroMQ套接字
|
||||
stop_event (threading.Event): 停止事件,用于通知线程退出
|
||||
"""
|
||||
print("开始持续接收服务器数据...")
|
||||
print("-" * 50)
|
||||
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
# 设置接收超时为1秒,避免阻塞
|
||||
socket.setsockopt(zmq.RCVTIMEO, 1000)
|
||||
# 接收服务器的消息
|
||||
frames = socket.recv_multipart()
|
||||
|
||||
# DEALER 套接字接收消息格式:[身份标识, 空帧, 消息内容]
|
||||
# 使用frames[-1]获取最后一帧,无论中间有多少空帧
|
||||
if len(frames) >= 2:
|
||||
message = frames[-1].decode('utf-8')
|
||||
|
||||
# 尝试解析为JSON格式
|
||||
try:
|
||||
json_message = json.loads(message)
|
||||
# 检查消息长度
|
||||
json_str = str(json_message)
|
||||
if len(json_str) > 100:
|
||||
print(f"收到服务器数据 (JSON): {json_str[:100]}...")
|
||||
else:
|
||||
print(f"收到服务器数据 (JSON): {json_message}")
|
||||
except json.JSONDecodeError:
|
||||
# 检查消息长度
|
||||
if len(message) > 100:
|
||||
print(f"收到服务器数据 (原始): {message[:100]}...")
|
||||
else:
|
||||
print(f"收到服务器数据 (原始): {message}")
|
||||
else:
|
||||
print(f"收到服务器数据 (格式异常): {frames}")
|
||||
|
||||
except zmq.Again:
|
||||
# 接收超时,继续循环
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"接收消息时发生错误: {e}")
|
||||
# 短暂暂停后继续接收
|
||||
time.sleep(1)
|
||||
|
||||
print("接收线程已停止。")
|
||||
|
||||
def zero_mq_client(server_address="tcp://127.0.0.1:8099"):
|
||||
"""
|
||||
ZeroMQ客户端函数,用于与服务器通信
|
||||
|
||||
Args:
|
||||
server_address (str): 服务器地址,格式为"tcp://IP:端口"
|
||||
"""
|
||||
# 创建 ZeroMQ 上下文
|
||||
context = zmq.Context()
|
||||
|
||||
# 创建 DEALER 套接字
|
||||
socket = context.socket(zmq.DEALER)
|
||||
|
||||
# 生成唯一的身份标识
|
||||
identity = str('wdd').encode('utf-8')
|
||||
socket.setsockopt(zmq.IDENTITY, identity)
|
||||
|
||||
try:
|
||||
# 连接到服务器
|
||||
print(f"连接到服务器 {server_address}...")
|
||||
socket.connect(server_address)
|
||||
|
||||
# 定义消息集
|
||||
message_set = [
|
||||
{"method": "sync", "params": 1},
|
||||
{"method": "decoderClass", "params": "mi"},
|
||||
{"method": "decoderClass", "params": "ssvep"},
|
||||
{"method": "decoderClass", "params": "ssmvep"},
|
||||
{"method": "decoderClass", "params": "blink"},
|
||||
{"method": "decoderClass", "params": "concentration"},
|
||||
{"method": "train", "params": 0},
|
||||
{"method": "train", "params": 1},
|
||||
{"method": "rest", "params": 0},
|
||||
{"method": "predict", "params": 1},
|
||||
{"method": "getReport", "params": 0},
|
||||
{"method": "targetFreqs", "params": [11, 12, 13]}
|
||||
]
|
||||
|
||||
# 打印消息集
|
||||
print("消息集:")
|
||||
for i, msg in enumerate(message_set):
|
||||
print(f"[{i}] {msg}")
|
||||
print("-" * 50)
|
||||
|
||||
# 创建停止事件
|
||||
stop_event = threading.Event()
|
||||
|
||||
# 启动接收线程
|
||||
receive_thread = threading.Thread(target=receive_messages, args=(socket, stop_event))
|
||||
receive_thread.daemon = True # 设置为守护线程,主线程退出时自动退出
|
||||
receive_thread.start()
|
||||
|
||||
# 主线程处理控制台输入
|
||||
print("输入消息序号发送对应消息,输入'q'退出程序:")
|
||||
while True:
|
||||
try:
|
||||
# 获取用户输入
|
||||
user_input = input("请输入消息序号: ")
|
||||
|
||||
# 检查是否退出
|
||||
if user_input.lower() == 'q':
|
||||
print("正在退出程序...")
|
||||
break
|
||||
|
||||
# 尝试转换为整数
|
||||
msg_index = int(user_input)
|
||||
|
||||
# 检查序号是否有效
|
||||
if 0 <= msg_index < len(message_set):
|
||||
# 获取对应的消息
|
||||
selected_message = message_set[msg_index]
|
||||
|
||||
# 将消息转换为 JSON 字符串
|
||||
json_message = json.dumps(selected_message)
|
||||
|
||||
# 打印发送信息
|
||||
print(f"\n发送消息 (大小: {len(json_message)} 字节)...")
|
||||
print(f"消息方法: {selected_message['method']}")
|
||||
print(f"参数值: {selected_message['params']}")
|
||||
|
||||
# DEALER 套接字发送消息,包含身份标识和空帧
|
||||
socket.send_multipart([identity, json_message.encode('utf-8')])
|
||||
print("消息发送完成!")
|
||||
print("-" * 50)
|
||||
else:
|
||||
print(f"无效的消息序号,请输入 0-{len(message_set)-1} 之间的数字。")
|
||||
print("消息集:")
|
||||
for i, msg in enumerate(message_set):
|
||||
print(f"[{i}] {msg}")
|
||||
print("-" * 50)
|
||||
|
||||
except ValueError:
|
||||
print("请输入有效的数字或'q'退出。")
|
||||
except Exception as e:
|
||||
print(f"处理输入时发生错误: {e}")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n程序被手动终止。")
|
||||
finally:
|
||||
# 停止接收线程
|
||||
stop_event.set()
|
||||
# 等待接收线程停止
|
||||
time.sleep(1)
|
||||
# 关闭套接字和上下文
|
||||
socket.close()
|
||||
context.term()
|
||||
print("客户端已关闭。")
|
||||
|
||||
if __name__ == "__main__":
|
||||
zero_mq_client()
|
||||
@@ -5,13 +5,12 @@
|
||||
import numpy as np
|
||||
from scipy import signal
|
||||
import threading
|
||||
from logs.log import algo_log
|
||||
|
||||
class ParadigmRingBuffer:
|
||||
def __init__(self, n_chan, n_points):
|
||||
self.n_chan = n_chan
|
||||
self.n_points = n_points
|
||||
self.buffer = np.zeros((n_chan, n_points), dtype=np.float64)
|
||||
self.buffer = np.zeros((n_chan, n_points))
|
||||
self.currentPtr = 0
|
||||
self.readPtr = 0
|
||||
self.nUpdate = 0
|
||||
@@ -20,8 +19,7 @@ class ParadigmRingBuffer:
|
||||
## append buffer and update current pointer
|
||||
def appendBuffer(self, data):
|
||||
if self.nUpdate == self.n_points:
|
||||
# raise Exception("Buffer is full")
|
||||
algo_log("ParadigmRingBuffer is full", record_once=True)
|
||||
raise Exception("Buffer is full")
|
||||
|
||||
n = data.shape[1]
|
||||
|
||||
@@ -67,56 +65,13 @@ class ParadigmRingBuffer:
|
||||
'''
|
||||
return self.nUpdate
|
||||
|
||||
# ========== 各范式数据访问接口 ==========
|
||||
def get_MIData(self):
|
||||
"""获取MI导联数据 (21通道 + 事件)"""
|
||||
data = self.getData(self.GetDataLenCount())
|
||||
rows_to_extract = [8, 15, 12, 14, 18, 23, 16, 59, 50, 58, 17, 45, 29, 11, 10, 19, 20, 61, 51, 60, 21, 64, 65]
|
||||
row_to_select = np.array(rows_to_extract)
|
||||
if data.shape[1] > 0:
|
||||
return data[row_to_select, :]
|
||||
return np.zeros((len(rows_to_extract), 0))
|
||||
|
||||
def get_SSMVEPData(self):
|
||||
"""获取SSMVEP导联数据 (8通道 + 事件)"""
|
||||
data = self.getData(self.GetDataLenCount())
|
||||
rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64, 65]
|
||||
row_to_select = np.array(rows_to_extract)
|
||||
if data.shape[1] > 0:
|
||||
return data[row_to_select, :]
|
||||
return np.zeros((len(rows_to_extract), 0))
|
||||
|
||||
def getDataViaSSVEP(self, count):
|
||||
"""获取SSVEP数据 (8通道 + 事件)"""
|
||||
data = self.getData(count)
|
||||
rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64]
|
||||
row_to_select = np.array(rows_to_extract)
|
||||
if data.shape[1] > 0:
|
||||
return data[row_to_select, :]
|
||||
return np.zeros((len(rows_to_extract), 0))
|
||||
|
||||
def get_concentrateData(self, count):
|
||||
"""获取专注力数据 (2通道)"""
|
||||
data = self.getData(count)
|
||||
rows_to_extract = [0, 1]
|
||||
row_to_select = np.array(rows_to_extract)
|
||||
if data.shape[1] > 0:
|
||||
return data[row_to_select, :]
|
||||
return np.zeros((len(rows_to_extract), 0))
|
||||
|
||||
def get_blinkData(self, count):
|
||||
"""获取眨眼数据 (2通道)"""
|
||||
data = self.getData(count)
|
||||
rows_to_extract = [0, 1]
|
||||
row_to_select = np.array(rows_to_extract)
|
||||
if data.shape[1] > 0:
|
||||
return data[row_to_select, :]
|
||||
return np.zeros((len(rows_to_extract), 0))
|
||||
|
||||
# reset buffer
|
||||
def resetAllPara(self):
|
||||
self.nUpdate = 0
|
||||
self.currentPtr = 0
|
||||
self.readPtr = 0
|
||||
self.buffer.fill(0.0)
|
||||
self.readPtr = 0 # add by lizhenhua 清空读指针
|
||||
self.buffer = np.zeros((self.n_chan, self.n_points)) # add by lizhenhua 清空环形缓冲区
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -3,310 +3,206 @@
|
||||
数据滤波模块
|
||||
"""
|
||||
import numpy as np
|
||||
import time
|
||||
import threading
|
||||
import queue
|
||||
from scipy import signal
|
||||
from logs.log import algo_log
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from Tools.beta_calculate import Beta_Calculate
|
||||
|
||||
class FilterRingBuffer:
|
||||
def __init__(self, n_chan, n_points):
|
||||
"""
|
||||
初始化纯数据环形缓存
|
||||
:param n_chan: 通道数
|
||||
:param n_points: 总缓存点数(与paradigmRingBuffer参数完全一致)
|
||||
"""
|
||||
self.n_chan = n_chan
|
||||
self.n_points = n_points
|
||||
|
||||
self.buffer = np.zeros((n_chan, n_points), dtype=np.float64)
|
||||
self.current_ptr = 0
|
||||
self.total_samples = 0
|
||||
self.lock = threading.Lock() # 仅保护元数据
|
||||
self.has_new_data = False
|
||||
self.current_ptr = 0 # 写入指针
|
||||
self.total_samples = 0 # 已写入总点数
|
||||
|
||||
# 线程安全锁(多线程环境必须)
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def appendBuffer(self, data):
|
||||
"""
|
||||
追加数据到缓存(与paradigmRingBuffer接口一致)
|
||||
:param data: 输入数据,shape=(n_chan, n_samples)
|
||||
"""
|
||||
with self.lock:
|
||||
n = data.shape[1]
|
||||
if n == 0:
|
||||
return
|
||||
|
||||
# 仅加锁读取/更新元数据
|
||||
with self.lock:
|
||||
old_ptr = self.current_ptr
|
||||
new_ptr = (old_ptr + n) % self.n_points
|
||||
new_total = min(self.total_samples + n, self.n_points)
|
||||
self.has_new_data = True
|
||||
|
||||
# 数组写入(耗时操作,移出锁外)
|
||||
write_end = old_ptr + n
|
||||
# 环形写入逻辑
|
||||
write_end = self.current_ptr + n
|
||||
if write_end <= self.n_points:
|
||||
self.buffer[:, old_ptr:write_end] = data
|
||||
self.buffer[:, self.current_ptr:write_end] = data
|
||||
else:
|
||||
split = self.n_points - old_ptr
|
||||
self.buffer[:, old_ptr:] = data[:, :split]
|
||||
split = self.n_points - self.current_ptr
|
||||
self.buffer[:, self.current_ptr:] = data[:, :split]
|
||||
self.buffer[:, :write_end - self.n_points] = data[:, split:]
|
||||
|
||||
# 再次加锁更新最终元数据
|
||||
with self.lock:
|
||||
self.current_ptr = new_ptr
|
||||
self.total_samples = new_total
|
||||
|
||||
# ========== 新增:获取&清空新数据标记的方法 ==========
|
||||
def check_and_clear_new_data(self):
|
||||
"""检查是否有新数据,并一次性清空标记(消费后重置)"""
|
||||
with self.lock:
|
||||
flag = self.has_new_data
|
||||
if flag:
|
||||
self.has_new_data = False
|
||||
return flag
|
||||
# 更新指针和计数
|
||||
self.current_ptr = write_end % self.n_points
|
||||
self.total_samples = min(self.total_samples + n, self.n_points)
|
||||
|
||||
def getData(self, count):
|
||||
# 加锁获取最新元数据
|
||||
"""
|
||||
从读指针位置读取count个点(与paradigmRingBuffer接口一致)
|
||||
:param count: 读取点数
|
||||
:return: np.ndarray, shape=(n_chan, count)
|
||||
"""
|
||||
with self.lock:
|
||||
count = min(count, self.total_samples)
|
||||
if count == 0:
|
||||
return np.zeros((self.n_chan, 0))
|
||||
|
||||
# 环形读取逻辑(与paradigmRingBuffer完全相同)
|
||||
end = self.current_ptr
|
||||
start = end - count
|
||||
|
||||
# 数据读取、切片、拼接(无锁)
|
||||
if start >= 0:
|
||||
res = self.buffer[:, start:end].copy()
|
||||
return self.buffer[:, start:end].copy()
|
||||
else:
|
||||
part1 = self.buffer[:, start:]
|
||||
part2 = self.buffer[:, :end]
|
||||
res = np.concatenate((part1, part2), axis=1).copy()
|
||||
return res
|
||||
return np.concatenate((part1, part2), axis=1)
|
||||
|
||||
def get_latest_n_points(self, n):
|
||||
"""
|
||||
扩展方法:获取最新的n个点(不移动读指针,用于滑动窗口)
|
||||
:param n: 点数
|
||||
:return: np.ndarray, shape=(n_chan, n)
|
||||
"""
|
||||
with self.lock:
|
||||
if self.total_samples < n:
|
||||
return None
|
||||
return self.getData(n)
|
||||
|
||||
def GetDataLenCount(self):
|
||||
"""获取当前缓存总点数(兼容原有接口)"""
|
||||
with self.lock:
|
||||
return self.total_samples
|
||||
|
||||
def resetAllPara(self):
|
||||
"""重置所有缓存和指针(兼容原有接口)"""
|
||||
with self.lock:
|
||||
self.buffer.fill(0.0)
|
||||
self.current_ptr = 0
|
||||
self.total_samples = 0
|
||||
self.has_new_data = False # 重置时清空新数据标记
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# 2. 独立 Beta PSD 计算线程(避免阻塞滤波主循环的 200ms 定时)
|
||||
# 2. 独立滑动滤波类(仅负责滤波业务逻辑,不关心缓存实现)
|
||||
# 可替换任意缓存实现,只要实现appendBuffer、get_latest_n_points接口
|
||||
# -----------------------------------------------------------------------------
|
||||
class BetaPsdCalculator(threading.Thread):
|
||||
"""独立的 Beta PSD 计算线程,使用队列与滤波主线程解耦"""
|
||||
|
||||
def __init__(self, fs=250, window_size=750):
|
||||
super().__init__(daemon=True)
|
||||
self.fs = fs
|
||||
self.window_size = window_size
|
||||
self._beta_calc = Beta_Calculate(Threshold_value_low=0, Threshold_value_high=0, fs=fs)
|
||||
self._input_queue = queue.Queue(maxsize=2)
|
||||
self._running = threading.Event()
|
||||
self._running.set()
|
||||
self._latest_beta = None
|
||||
self._beta_lock = threading.Lock()
|
||||
self.beta_broadcast_callback = None
|
||||
|
||||
def push_data(self, data):
|
||||
"""供外部调用的线程安全数据推送接口"""
|
||||
try:
|
||||
self._input_queue.put_nowait(data)
|
||||
except queue.Full:
|
||||
try:
|
||||
self._input_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
pass
|
||||
try:
|
||||
self._input_queue.put_nowait(data)
|
||||
except queue.Full:
|
||||
pass
|
||||
|
||||
def get_latest_beta(self):
|
||||
"""获取最新的 beta 值(线程安全)"""
|
||||
with self._beta_lock:
|
||||
return self._latest_beta
|
||||
|
||||
def run(self):
|
||||
while self._running.is_set():
|
||||
try:
|
||||
data = self._input_queue.get(timeout=1.5)
|
||||
if data is None:
|
||||
break
|
||||
try:
|
||||
beta_psd, _, _ = self._beta_calc.calculate_all(
|
||||
data, fs=self.fs, nperseg=min(self.window_size, data.shape[1])
|
||||
)
|
||||
with self._beta_lock:
|
||||
self._latest_beta = round(float(beta_psd), 3)
|
||||
if self.beta_broadcast_callback is not None:
|
||||
self.beta_broadcast_callback(self._latest_beta)
|
||||
except Exception as e:
|
||||
algo_log(f"Beta PSD 计算异常: {e}", level='error')
|
||||
except queue.Empty:
|
||||
pass
|
||||
|
||||
def stop(self):
|
||||
"""停止计算线程"""
|
||||
self._running.clear()
|
||||
try:
|
||||
self._input_queue.put_nowait(None)
|
||||
except queue.Full:
|
||||
pass
|
||||
if self.is_alive():
|
||||
self.join(timeout=2)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# 3. 独立滑动滤波类(仅负责滤波业务逻辑,不关心缓存实现)
|
||||
# -----------------------------------------------------------------------------
|
||||
class SlidingFilter(threading.Thread):
|
||||
class SlidingFilter:
|
||||
def __init__(
|
||||
self,
|
||||
ring_buffer: FilterRingBuffer,
|
||||
n_chan=66,
|
||||
srate=250,
|
||||
buffer_sec=5,
|
||||
window_sec=3,
|
||||
step_sec=0.2
|
||||
step_sec=0.2,
|
||||
packet_size=5
|
||||
):
|
||||
super().__init__(daemon=True)
|
||||
"""
|
||||
初始化滑动滤波器
|
||||
:param n_chan: 通道数
|
||||
:param srate: 采样率
|
||||
:param buffer_sec: 总缓存时长(秒)
|
||||
:param window_sec: 滤波窗口时长(秒)
|
||||
:param step_sec: 滑动步长/输出时长(秒)
|
||||
:param packet_size: 每包数据点数(20ms一包=5点)
|
||||
"""
|
||||
# 核心参数
|
||||
self.n_chan = n_chan
|
||||
self.srate = srate
|
||||
self.step_sec = step_sec # 200ms滑动步长
|
||||
self.window_sec = window_sec # 3秒窗口
|
||||
self.step_sec = step_sec # 200ms滑动步长
|
||||
self.window_size = int(srate * window_sec) # 3秒点数:250*3=750
|
||||
self.step_size = int(srate * step_sec) # 200ms点数:250*0.2=50
|
||||
self.buffer_size = int(srate * buffer_sec)
|
||||
self.window_size = int(srate * window_sec)
|
||||
self.step_size = int(srate * step_sec)
|
||||
self.packet_size = packet_size
|
||||
|
||||
# 关联ZMQServer的环形缓存(解耦:仅依赖接口)
|
||||
self.ring_buffer = ring_buffer
|
||||
# 线程控制
|
||||
self.running = threading.Event()
|
||||
self.running.set()
|
||||
# 滤波结果回调(外部可注册,获取滤波后的数据)
|
||||
self.filter_result_callback = None
|
||||
# 初始化纯数据缓存(解耦核心)
|
||||
self.buffer = FilterRingBuffer(n_chan, self.buffer_size)
|
||||
|
||||
# beta 每秒触发计数(200ms步长,5次 = 1s)
|
||||
self._beta_step_counter = 0
|
||||
self._beta_steps_per_second = max(1, int(round(1.0 / step_sec))) # 5
|
||||
# 滤波触发计数器
|
||||
self.packet_count = 0
|
||||
self.ready_to_filter = False
|
||||
|
||||
self.slide_window = None # 滑动窗口缓存 (n_chan, window_size)
|
||||
self.slide_ready = False # 窗口是否已填满初始数据
|
||||
# 预计算滤波器系数(仅执行一次)
|
||||
# 预计算滤波器系数
|
||||
self._init_filters()
|
||||
|
||||
# 独立的 Beta 计算线程(避免阻塞滤波主循环)
|
||||
self._beta_thread = BetaPsdCalculator(fs=srate, window_size=self.window_size)
|
||||
|
||||
def start(self):
|
||||
"""同时启动 Beta 计算线程和滤波主线程"""
|
||||
self._beta_thread.start()
|
||||
super().start()
|
||||
|
||||
def set_beta_broadcast_callback(self, callback):
|
||||
"""注册 Beta PSD 广播回调函数"""
|
||||
self._beta_thread.beta_broadcast_callback = callback
|
||||
|
||||
def _init_filters(self):
|
||||
"""预计算所有滤波器系数(仅执行一次)"""
|
||||
# 50Hz工频陷波(Q=30,工业标准)
|
||||
self.b_notch, self.a_notch = signal.iirnotch(50, 30, self.srate)
|
||||
# 0.5~45Hz带通FIR(65阶,线性相位)
|
||||
# 8~30Hz带通FIR(65阶,线性相位)
|
||||
self.b_bp = signal.firwin(
|
||||
numtaps=65,
|
||||
cutoff=[0.5/(self.srate/2), 45/(self.srate/2)],
|
||||
cutoff=[8/(self.srate/2), 30/(self.srate/2)],
|
||||
pass_zero=False,
|
||||
window='hamming'
|
||||
)
|
||||
self.a_bp = np.array([1.0])
|
||||
|
||||
def _filter_window_data(self, window_data):
|
||||
"""对3秒窗口数据执行滤波,返回 (无边界效应的200ms数据, 完整3s滤波数据)"""
|
||||
def append_and_check_trigger(self, raw_data):
|
||||
"""
|
||||
追加单包原始数据并检查是否触发滤波
|
||||
:param raw_data: 上位机原始数据,shape=(packet_size, n_chan)
|
||||
:return: bool: 是否触发本次滤波
|
||||
"""
|
||||
# 转置为标准格式:(通道数, 点数)
|
||||
data = raw_data.T.astype(np.float64)
|
||||
|
||||
# 写入缓存(纯缓存操作)
|
||||
self.buffer.appendBuffer(data)
|
||||
|
||||
# 更新包计数器
|
||||
self.packet_count += 1
|
||||
|
||||
# 检查滤波触发条件:数据≥窗口长度 且 累计满一个步长的包数
|
||||
packets_per_step = int(self.step_size / self.packet_size) # 10包=200ms
|
||||
if (self.buffer.GetDataLenCount() >= self.window_size
|
||||
and self.packet_count >= packets_per_step):
|
||||
self.packet_count = 0
|
||||
self.ready_to_filter = True
|
||||
return True
|
||||
return False
|
||||
|
||||
def filter_and_get_output(self):
|
||||
"""
|
||||
执行滤波并返回无边界效应的输出数据
|
||||
:return: np.ndarray: 滤波后数据,shape=(n_chan, step_size)
|
||||
"""
|
||||
if not self.ready_to_filter:
|
||||
return None
|
||||
|
||||
# 获取最新的完整滤波窗口数据
|
||||
window_data = self.buffer.get_latest_n_points(self.window_size)
|
||||
if window_data is None:
|
||||
self.ready_to_filter = False
|
||||
return None
|
||||
|
||||
# 零相位滤波(无延迟,无边界效应)
|
||||
filtered = window_data - np.mean(window_data, axis=-1, keepdims=True)
|
||||
filtered = signal.filtfilt(self.b_notch, self.a_notch, filtered, axis=-1)
|
||||
filtered = signal.filtfilt(self.b_bp, self.a_bp, filtered, axis=-1)
|
||||
|
||||
# 提取倒数第二个200ms的数据(完全避开两端边界效应)
|
||||
# 窗口长度750,步长50 → start=750-100=650,end=750-50=700
|
||||
# 提取倒数第二个步长的数据(完全避开两端边界效应)
|
||||
start_idx = self.window_size - 2 * self.step_size
|
||||
end_idx = self.window_size - self.step_size
|
||||
output_data = filtered[:, start_idx:end_idx].copy()
|
||||
return output_data, filtered
|
||||
|
||||
def run(self):
|
||||
"""线程主逻辑:精确200ms触发一次滤波"""
|
||||
interval = self.step_sec # 0.2s
|
||||
# 以启动时刻为绝对时间基准(核心改动)
|
||||
base_time = time.perf_counter()
|
||||
frame_count = 0 # 帧计数器,用于对齐时序
|
||||
# 重置触发标志
|
||||
self.ready_to_filter = False
|
||||
|
||||
while self.running.is_set():
|
||||
# 计算理论执行时刻:严格按帧序号 × 步长
|
||||
expect_time = base_time + frame_count * interval
|
||||
current_time = time.perf_counter()
|
||||
return output_data
|
||||
|
||||
# 精确定时等待
|
||||
if current_time < expect_time:
|
||||
time.sleep(expect_time - current_time)
|
||||
else:
|
||||
# 处理超时:仅告警,不重置基准(防止累积偏移)
|
||||
algo_log(f"滤波任务超时,偏移 {(current_time - expect_time)*1000:.1f} ms", level='debug')
|
||||
def reset(self):
|
||||
"""重置滤波器和缓存"""
|
||||
self.buffer.resetAllPara()
|
||||
self.packet_count = 0
|
||||
self.ready_to_filter = False
|
||||
|
||||
frame_count += 1 # 帧序号自增,保证周期绝对稳定
|
||||
if not self.ring_buffer.check_and_clear_new_data():
|
||||
# 无新数据,不执行滤波、不发送数据
|
||||
continue
|
||||
|
||||
# ========== 原有滤波逻辑 ==========
|
||||
try:
|
||||
if not self.slide_ready:
|
||||
# 阶段1:首次填满3s初始窗口
|
||||
full_data = self.ring_buffer.get_latest_n_points(self.window_size)
|
||||
if full_data is None:
|
||||
algo_log("初始窗口数据不足", level='debug')
|
||||
continue
|
||||
self.slide_window = full_data
|
||||
self.slide_ready = True
|
||||
else:
|
||||
# 阶段2:正常滑动 → 取最新50个新点,增量拼接
|
||||
new_step_data = self.ring_buffer.get_latest_n_points(self.step_size)
|
||||
if new_step_data is None:
|
||||
algo_log("滑动步长数据不足", level='debug')
|
||||
continue
|
||||
# 增量滑动:丢弃前50点,拼接新50点(标准滑动窗口)
|
||||
self.slide_window = np.hstack([
|
||||
self.slide_window[:, self.step_size:],
|
||||
new_step_data
|
||||
])
|
||||
|
||||
filtered_data, filtered_full = self._filter_window_data(self.slide_window[:64, :])
|
||||
|
||||
# Beta PSD 每秒计算一次
|
||||
self._beta_step_counter += 1
|
||||
if self._beta_step_counter >= self._beta_steps_per_second:
|
||||
self._beta_step_counter = 0
|
||||
self._beta_thread.push_data(filtered_full[:2, :])
|
||||
|
||||
if self.filter_result_callback is not None:
|
||||
self.filter_result_callback(filtered_data)
|
||||
except Exception as e:
|
||||
algo_log(f"滤波执行异常: {e}", level='error')
|
||||
|
||||
def set_result_callback(self, callback):
|
||||
"""注册滤波结果回调函数"""
|
||||
self.filter_result_callback = callback
|
||||
|
||||
def stop(self):
|
||||
"""停止滤波线程和 Beta 计算线程"""
|
||||
self._beta_thread.stop()
|
||||
self.running.clear()
|
||||
if self.is_alive():
|
||||
self.join(timeout=1)
|
||||
if self.is_alive():
|
||||
algo_log("警告:滤波线程在1秒内未正常退出,可能存在阻塞操作", level="WARNING")
|
||||
algo_log("滤波线程已停止")
|
||||
def get_buffer_length(self):
|
||||
"""获取当前缓存数据长度"""
|
||||
return self.buffer.GetDataLenCount()
|
||||
484
Zmq/zmqServer.py
484
Zmq/zmqServer.py
@@ -1,449 +1,241 @@
|
||||
# -*-coding:utf-8 -*-
|
||||
import ast
|
||||
import numpy as np
|
||||
import threading
|
||||
import zmq
|
||||
import threading
|
||||
import json
|
||||
import queue
|
||||
from typing import Dict
|
||||
import datetime
|
||||
import time
|
||||
|
||||
from Zmq.dataBuffer import ParadigmRingBuffer
|
||||
from Zmq.filterProcess import FilterRingBuffer
|
||||
from PubLibrary.InifileHelper import IniRead
|
||||
# from Device.SunnyLinker import SunnyLinker64
|
||||
from dataBuffer import ParadigmRingBuffer
|
||||
from filterProcess import FilterRingBuffer
|
||||
from logs.log import algo_log
|
||||
|
||||
zmqServer_host = str(IniRead('system', 'zmqServer_host', '127.0.0.1'))
|
||||
|
||||
class zmqServer(threading.Thread):
|
||||
def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100, device_info=None):
|
||||
threading.Thread.__init__(self)
|
||||
self.device_info = device_info
|
||||
|
||||
self.host = zmqServer_host
|
||||
|
||||
self.cmd_port = cmd_port # 命令交互端口:收JSON命令 + 返JSON结果
|
||||
self.data_port = data_port # 数据交互端口:收二进制原始脑电 + 返二进制滤波结果
|
||||
self.host = host
|
||||
self.cmd_port = cmd_port # 命令交互端口
|
||||
self.data_port = data_port # 数据接收端口
|
||||
self.running = False
|
||||
|
||||
# 原有业务状态变量
|
||||
self.open_Impedance = False #当前系统处于阻抗检测状态
|
||||
self.StartDecode = False
|
||||
self.StartTrain = False
|
||||
self.state_mode = None
|
||||
self.currentLabel = -1
|
||||
self.IsExitApp = False
|
||||
# self.get_Impedance = False # 是否返回阻抗值
|
||||
# self.open_Impedance = None # 是否开启阻抗检测功能
|
||||
self.StartDecode = False # false 停止解码,true=开始解码
|
||||
self.StartTrain = False # False未进入训练状态,True处于训练状态
|
||||
self.state_mode = None # 'train'为训练状态,’rest'为休息状态,'test'为测试状态
|
||||
self.currentLabel = -1 # 接收刺激端消息,了解刺激端当前的训练标签
|
||||
self.IsExitApp = False # 当socket收到2的时候,就置为True,代表要退出系统了。
|
||||
# self.getReport = False # 获取训练报告内容
|
||||
self.daemon = True
|
||||
|
||||
# 双环形缓冲区
|
||||
self.paradigmBuffer = ParadigmRingBuffer(
|
||||
self.device_info['channel_nums'],
|
||||
self.device_info['sample_rate'] * 10
|
||||
)
|
||||
self.filterBuffer = FilterRingBuffer(
|
||||
self.device_info['channel_nums'],
|
||||
self.device_info['sample_rate'] * 10
|
||||
)
|
||||
self.paradigmBufferLock = threading.Lock()
|
||||
self.filterBufferLock = threading.Lock()
|
||||
# 范式数据缓存
|
||||
self.paradigmBuffer = ParadigmRingBuffer(66, 2500)
|
||||
self.filterBuffer = FilterRingBuffer(66, 2500)
|
||||
|
||||
# ZMQ上下文与套接字
|
||||
|
||||
# 命令与数据通信
|
||||
self.context = zmq.Context()
|
||||
|
||||
# 8099命令端口:ROUTER
|
||||
# 指令通道 (8099) - ROUTER:短JSON命令,低频率
|
||||
self.cmd_socket = self.context.socket(zmq.ROUTER)
|
||||
self.cmd_socket.setsockopt(zmq.SocketOption.RCVHWM, 100)
|
||||
self.cmd_socket.setsockopt(zmq.SocketOption.SNDHWM, 100)
|
||||
self.cmd_socket.setsockopt(zmq.RCVHWM, 100) # 指令不需要大缓存,100条足够
|
||||
self.cmd_socket.setsockopt(zmq.SNDHWM, 100)
|
||||
self.cmd_socket.setsockopt(zmq.TCP_NODELAY, 1) # 禁用Nagle算法,降低指令延迟
|
||||
self.cmd_socket.bind(f"tcp://{self.host}:{cmd_port}")
|
||||
|
||||
# 8100数据端口:ROUTER
|
||||
# 数据通道 (8100) - ROUTER:高频脑电二进制流
|
||||
self.data_socket = self.context.socket(zmq.ROUTER)
|
||||
self.data_socket.setsockopt(zmq.SocketOption.RCVHWM, 500)
|
||||
self.data_socket.setsockopt(zmq.SocketOption.SNDHWM, 100) # 添加发送高水位线
|
||||
self.data_socket.setsockopt(zmq.RCVHWM, 500) # 500包=10秒缓存,足够应对短时卡顿
|
||||
self.data_socket.setsockopt(zmq.TCP_NODELAY, 1) # 禁用Nagle算法,减少数据传输延迟
|
||||
self.data_socket.bind(f"tcp://{self.host}:{data_port}")
|
||||
|
||||
# Poller轮询器
|
||||
# Poller 轮训器(保持不变)
|
||||
self.poller = zmq.Poller()
|
||||
self.poller.register(self.cmd_socket, zmq.POLLIN)
|
||||
self.poller.register(self.data_socket, zmq.POLLIN)
|
||||
|
||||
# 业务变量
|
||||
self.targetFreqs = []
|
||||
self.changeTarget = False
|
||||
self.changeTarget = False # 更换目标频率
|
||||
# self.sunnyLinker = SunnyLinker64(None, None, None, None,None) #单例模式类,已在Decoder实例化
|
||||
self.labels = [0x01, 0x02,0x03]
|
||||
self.decoder_switch = False
|
||||
self.decoder_class = None
|
||||
self.decoder_switch = False #更换解码器
|
||||
self.decoder_class = None #解码器类别 'ssvep','ssmvep','mi'
|
||||
|
||||
# 客户端管理(单客户端场景)
|
||||
self.cmd_clients = set()
|
||||
self.data_clients = set()
|
||||
self.current_data_client = None # 唯一数据客户端身份,用于发送滤波结果
|
||||
# 客户端管理 - 区分命令/数据客户端
|
||||
self.cmd_clients = set() # 命令端口客户端ID
|
||||
self.data_clients = set() # 数据端口客户端ID
|
||||
self.send_queue = queue.Queue() # 发送队列(仅用于命令端口广播)
|
||||
|
||||
# 发送队列(双端口分离)
|
||||
self.cmd_send_queue = queue.Queue() # 8099端口命令结果队列
|
||||
self.data_send_queue = queue.Queue() # 8100端口滤波数据队列
|
||||
|
||||
# 范式buffer与事件检测参数
|
||||
self.predict_event = 99
|
||||
self.events = [1, 2, self.predict_event]
|
||||
self.latency = 50
|
||||
self.train_latency = 50
|
||||
self.count_events = {}
|
||||
self.epoch_finished = False
|
||||
self.pack_contain_event = False
|
||||
self.event_inner_idx = -1
|
||||
self.interval_inited = False
|
||||
self.last_epoch_finish_time = None
|
||||
|
||||
def reset_state(self):
|
||||
"""清空采集器状态和缓存数据"""
|
||||
with self.paradigmBufferLock:
|
||||
self.paradigmBuffer.resetAllPara()
|
||||
self.count_events = {}
|
||||
self.epoch_finished = False
|
||||
self.pack_contain_event = False
|
||||
self.event_inner_idx = -1
|
||||
self.interval_inited = False
|
||||
|
||||
def interval_init(self, decoder_class):
|
||||
if decoder_class == 'ssmvep':
|
||||
interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch')) # [0.2, 2.2]
|
||||
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] # [50, 550]
|
||||
self.train_epoch = [
|
||||
int(self.interval_epoch[0]),
|
||||
int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])
|
||||
] # [50, 575]
|
||||
self.latency = (self.interval_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5 #115包, 575个点
|
||||
self.train_latency = (self.train_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5 #120包 600个点
|
||||
|
||||
elif decoder_class == 'mi':
|
||||
interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch')) # [0.5, 4.5]
|
||||
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] #[125, 1125]
|
||||
self.train_epoch = self.interval_epoch.copy()
|
||||
self.latency = self.interval_epoch[1] // 5 #225
|
||||
self.train_latency = self.latency #225
|
||||
|
||||
algo_log(f"时间窗初始化完成: {interval_epoch}", level="INFO")
|
||||
self.count_events: Dict[str, int] = {}
|
||||
self.event_inner_idx = -1
|
||||
self.epoch_finished = False
|
||||
self.pack_contain_event = False
|
||||
self.predict_event = 99
|
||||
self.events = [1, 2, self.predict_event]
|
||||
self.interval_inited = True
|
||||
|
||||
# -------------------------- 8099端口:命令结果广播 --------------------------
|
||||
def broadcast_message(self, method, params):
|
||||
"""
|
||||
向所有8099端口客户端广播JSON格式的命令结果
|
||||
用于:解码结果、训练状态、错误提示、进度通知等
|
||||
"""
|
||||
self.cmd_send_queue.put((method, params))
|
||||
"""Put message into queue to be sent to all command clients"""
|
||||
self.send_queue.put((method, params))
|
||||
|
||||
def _process_cmd_send_queue(self):
|
||||
"""处理8099端口发送队列,在主线程执行(保证ZMQ线程安全)"""
|
||||
while not self.cmd_send_queue.empty():
|
||||
method, params = self.cmd_send_queue.get()
|
||||
if not self.cmd_clients:
|
||||
continue
|
||||
|
||||
try:
|
||||
msg = {'method': method, 'params': params}
|
||||
msg_bytes = json.dumps(msg).encode('utf-8')
|
||||
|
||||
if msg['method'] != 'beta_psd':
|
||||
algo_log(f"发送命令结果: {msg}", level="DEBUG")
|
||||
|
||||
# 广播到所有命令客户端
|
||||
for client_id in list(self.cmd_clients):
|
||||
try:
|
||||
self.cmd_socket.send_multipart([client_id, b"", msg_bytes])
|
||||
except Exception as e:
|
||||
algo_log(f"向命令客户端{client_id}发送失败: {e}", level="ERROR")
|
||||
self.cmd_clients.discard(client_id)
|
||||
|
||||
except Exception as e:
|
||||
algo_log(f"命令结果打包失败: {e}", level="ERROR")
|
||||
|
||||
# -------------------------- 8100端口:滤波结果发送 --------------------------
|
||||
def send_filtered_data(self, filtered_data):
|
||||
"""
|
||||
向8100端口客户端发送二进制格式的滤波结果
|
||||
用于:上位机实时绘图的脑电波形数据
|
||||
:param filtered_data: 滤波后数据,shape=(通道数, 50),float64格式
|
||||
"""
|
||||
if self.current_data_client is None:
|
||||
algo_log("数据客户端未连接,跳过滤波数据发送", level="WARNING")
|
||||
return
|
||||
|
||||
# 转置为上位机需要的[50, 通道数]格式
|
||||
filtered_data = filtered_data.T.astype(np.float64)
|
||||
send_buf = filtered_data.tobytes()
|
||||
# algo_log(f"发送滤波数据,长度: {len(send_buf)}字节, filtered_data.shape: {filtered_data.shape}", level="DEBUG", record_once=True)
|
||||
self.data_send_queue.put(send_buf)
|
||||
|
||||
def _process_data_send_queue(self):
|
||||
"""处理8100端口发送队列,在主线程执行(保证ZMQ线程安全)"""
|
||||
while not self.data_send_queue.empty():
|
||||
send_buf = self.data_send_queue.get()
|
||||
if self.current_data_client is None:
|
||||
continue
|
||||
|
||||
try:
|
||||
# 标准ROUTER发送格式:[客户端ID, 空分隔帧, 数据帧]
|
||||
self.data_socket.send_multipart([
|
||||
self.current_data_client,
|
||||
b"",
|
||||
send_buf
|
||||
])
|
||||
algo_log(f"发送滤波数据成功,长度: {len(send_buf)}字节", level="DEBUG", record_once=True)
|
||||
|
||||
except Exception as e:
|
||||
algo_log(f"发送滤波数据失败: {e}", level="ERROR")
|
||||
# 客户端断开,重置身份
|
||||
self.current_data_client = None
|
||||
self.data_clients.clear()
|
||||
|
||||
# -------------------------- 命令端口消息处理 --------------------------
|
||||
def _handle_cmd_message(self, frames):
|
||||
"""处理8099端口JSON命令消息"""
|
||||
"""处理命令端口消息(原有命令交互逻辑)"""
|
||||
if len(frames) < 3:
|
||||
algo_log(f"无效命令帧:长度不足3帧,实际{len(frames)}", level="ERROR")
|
||||
return
|
||||
|
||||
ident, _, message_bytes = frames[:3]
|
||||
|
||||
# 注册新的命令客户端
|
||||
if ident not in self.cmd_clients:
|
||||
self.cmd_clients.add(ident)
|
||||
algo_log(f"新命令客户端连接成功: {ident}", level="INFO")
|
||||
print(f"New CMD Client Connected: {ident} (port: {self.cmd_port})")
|
||||
|
||||
# 解析JSON命令
|
||||
# 解析消息
|
||||
try:
|
||||
message = json.loads(message_bytes.decode('utf-8'))
|
||||
except json.JSONDecodeError:
|
||||
algo_log(f"无效JSON命令: {message_bytes.hex()}", level="ERROR")
|
||||
self.broadcast_message("error", {"code": 400, "message": "无效JSON格式"})
|
||||
return
|
||||
except Exception as e:
|
||||
algo_log(f"_handle_cmd_message exception: {e}", level="ERROR")
|
||||
return
|
||||
print(f"Invalid JSON from CMD client {ident}")
|
||||
continue
|
||||
print(f"Received CMD request: {message}")
|
||||
|
||||
algo_log(f"收到命令: {message}", level="INFO")
|
||||
method = message.get("method")
|
||||
params = message.get("params")
|
||||
|
||||
# 命令处理逻辑
|
||||
# 原有命令处理逻辑
|
||||
if method == "sync":
|
||||
self.state_mode = 'sync'
|
||||
elif method == "targetFreqs":
|
||||
if method == "targetFreqs":
|
||||
if not isinstance(params, list):
|
||||
algo_log(f"targetFreqs must be a list")
|
||||
return
|
||||
print('targetFreqs must be a list')
|
||||
continue
|
||||
if params != self.targetFreqs:
|
||||
self.targetFreqs = params
|
||||
self.changeTarget = True
|
||||
elif method == "decoderClass":
|
||||
if method == "decoderClass":
|
||||
if not isinstance(params, str):
|
||||
algo_log(f"decoderClass必须是字符串")
|
||||
return
|
||||
print('decoderClass must be a str')
|
||||
continue
|
||||
if params != self.decoder_class:
|
||||
self.decoder_class = params
|
||||
self.decoder_switch = True
|
||||
elif method == "train":
|
||||
if method == "getReport":
|
||||
self.getReport = True
|
||||
if method == "train":#训练状态
|
||||
self.state_mode = 'train'
|
||||
resp = {
|
||||
"method": "train_response",
|
||||
"params": {
|
||||
"code": 200,
|
||||
"message": "ok"
|
||||
}
|
||||
}
|
||||
try:
|
||||
resp_bytes = json.dumps(resp, ensure_ascii=False).encode("utf-8")
|
||||
self.cmd_socket.send_multipart([ident, b"", resp_bytes])
|
||||
algo_log(f"train 命令已即时回复客户端 {ident}", level="DEBUG")
|
||||
except Exception as e:
|
||||
algo_log(f"train 命令回复失败: {e}", level="ERROR")
|
||||
return
|
||||
elif method == "predict":
|
||||
self.StartTrain = True
|
||||
self.currentLabel = params # 当前刺激端的训练标签
|
||||
self.sunnyLinker.push_trigger(self.labels[self.currentLabel])
|
||||
elif method == "predict":#预测状态
|
||||
self.state_mode = 'predict'
|
||||
if params == 1: #开始解码
|
||||
self.StartDecode = True
|
||||
self.sunnyLinker.push_trigger(0x63)
|
||||
elif params == 2: #停止解码
|
||||
self.IsExitApp = True
|
||||
self.running = False
|
||||
|
||||
resp = {
|
||||
"method": "predict_response",
|
||||
"params": {
|
||||
"code": 200,
|
||||
"message": "ok"
|
||||
}
|
||||
}
|
||||
try:
|
||||
resp_bytes = json.dumps(resp, ensure_ascii=False).encode("utf-8")
|
||||
self.cmd_socket.send_multipart([ident, b"", resp_bytes])
|
||||
algo_log(f"predict 命令已即时回复客户端 {ident}", level="DEBUG")
|
||||
except Exception as e:
|
||||
algo_log(f"predict 命令回复失败: {e}", level="ERROR")
|
||||
return
|
||||
|
||||
elif method == "rest":
|
||||
elif method == "rest": #休息状态
|
||||
self.state_mode = 'rest'
|
||||
elif method == "impedance":
|
||||
if params == 1:
|
||||
self.open_Impedance = True
|
||||
elif params == 2:
|
||||
self.open_Impedance = False
|
||||
else:
|
||||
self.broadcast_message("error", {"code": 404, "message": f"未知命令: {method}"})
|
||||
# elif method == "impedance":
|
||||
# if params == 1:
|
||||
# self.open_Impedance = True # 开启阻抗
|
||||
# self.get_Impedance = True # 返回阻抗
|
||||
# elif params == 2:
|
||||
# self.open_Impedance = False # 关闭阻抗
|
||||
# self.get_Impedance = False # 停止返回阻抗
|
||||
|
||||
# -------------------------- 数据端口消息处理 --------------------------
|
||||
def _handle_data_message(self, frames):
|
||||
"""处理8100端口二进制脑电数据消息"""
|
||||
algo_log(f"收到数据帧,总帧数:{len(frames)}", level="DEBUG", record_once=True)
|
||||
# 然后再进行解析
|
||||
if len(frames) == 4:
|
||||
# 你的上位机格式
|
||||
ident, sender_ident, empty_sep, data_bytes = frames[:4]
|
||||
elif len(frames) == 3:
|
||||
# 标准格式
|
||||
ident, empty_sep, data_bytes = frames[:3]
|
||||
elif len(frames) == 2:
|
||||
ident, data_bytes = frames[:2]
|
||||
else:
|
||||
return
|
||||
# 注册新的数据客户端(单客户端场景,自动覆盖旧身份)
|
||||
if ident not in self.data_clients:
|
||||
self.data_clients.clear() # 单客户端,只保留最新连接
|
||||
self.data_clients.add(ident)
|
||||
self.current_data_client = ident
|
||||
algo_log(f"新数据客户端连接成功: {ident}", level="INFO")
|
||||
try:
|
||||
# 精确长度校验
|
||||
EXPECTED_BYTES = self.device_info['frame_points'] * self.device_info['channel_nums'] * np.dtype(np.float64).itemsize
|
||||
if len(data_bytes) != EXPECTED_BYTES:
|
||||
algo_log(f"数据长度错误:期望{EXPECTED_BYTES}字节,实际{len(data_bytes)}字节", level="ERROR")
|
||||
"""
|
||||
处理8100端口原始脑电二进制数据
|
||||
固定格式:上位机发送 (5,66) float32 二维数组字节流(已转换为微伏物理量)→ 转置为 (66,5) 写入双缓冲区
|
||||
"""
|
||||
# 1. 校验ZMQ消息帧完整性
|
||||
if len(frames) < 3:
|
||||
print(f"[ERROR] 无效数据帧:长度不足3帧,实际长度={len(frames)}")
|
||||
return
|
||||
|
||||
# 零拷贝解析 + 维度转换
|
||||
data_np = np.frombuffer(data_bytes, dtype=np.float64)
|
||||
data_np = data_np.reshape(self.device_info['frame_points'], self.device_info['channel_nums'])
|
||||
ident, _, data_bytes = frames[:3]
|
||||
|
||||
# 2. 客户端管理(单客户端场景,自动更新最新身份)
|
||||
if ident not in self.data_clients:
|
||||
self.data_clients.add(ident)
|
||||
self.current_data_client = ident # 保存唯一客户端身份,用于后续回复滤波结果
|
||||
print(f"[INFO] 新数据客户端连接成功:{ident}")
|
||||
|
||||
try:
|
||||
# 3. 精确长度校验(核心:固定(5,66) float32 = 5*66*4=1320字节,与int32字节数相同)
|
||||
EXPECTED_BYTES = 5 * 66 * 4 # 每个float32占4字节
|
||||
if len(data_bytes) != EXPECTED_BYTES:
|
||||
print(f"[ERROR] 数据长度错误:期望{EXPECTED_BYTES}字节,实际{len(data_bytes)}字节")
|
||||
return
|
||||
|
||||
# 4. 零拷贝二进制解析 + 维度转换
|
||||
# 步骤:字节流 → (330,) float32数组 → (5,66) 原始格式 → 转置为 (66,5) 缓冲区标准格式
|
||||
data_np = np.frombuffer(data_bytes, dtype=np.float32)
|
||||
# 重塑为上位机原始维度
|
||||
data_np = data_np.reshape(5, 66)
|
||||
# 转置为(通道数, 采样点数)标准格式,转换为float64保证滤波运算精度
|
||||
data_np = data_np.T.astype(np.float64)
|
||||
|
||||
# 写入滤波缓冲区
|
||||
with self.filterBufferLock:
|
||||
# 5. 同时写入双环形缓冲区(方法名与现有类保持一致:appendBuffer)
|
||||
# 注意:上位机已发送微伏物理量,无需再乘以增益系数
|
||||
self.paradigmBuffer.appendBuffer(data_np)
|
||||
self.filterBuffer.appendBuffer(data_np)
|
||||
|
||||
# 写入范式缓冲区
|
||||
with self.paradigmBufferLock:
|
||||
if self.interval_inited:
|
||||
self.epoch_finished = self.detect_event(data_np)
|
||||
if self.pack_contain_event:
|
||||
self.paradigmBuffer.resetAllPara()
|
||||
self.paradigmBuffer.appendBuffer(data_np)
|
||||
|
||||
if self.epoch_finished:
|
||||
now = datetime.datetime.now()
|
||||
time_diff_str = ""
|
||||
# 计算与上一次Epoch完成的时间差
|
||||
if self.last_epoch_finish_time is not None:
|
||||
# 时间差 单位:秒,保留3位小数
|
||||
delta_seconds = (now - self.last_epoch_finish_time).total_seconds()
|
||||
time_diff_str = f" | 与上一次间隔: {delta_seconds:.3f} s"
|
||||
|
||||
# 拼接日志,增加时间差信息
|
||||
log_msg = f"Epoch采集完成: {now.strftime('%H:%M:%S.%f')[:-3]}{time_diff_str}"
|
||||
algo_log(log_msg, level="DEBUG")
|
||||
|
||||
# 更新上一次Epoch完成时间为当前时间
|
||||
self.last_epoch_finish_time = now
|
||||
else:
|
||||
self.paradigmBuffer.appendBuffer(data_np)
|
||||
# 生产环境必须注释!每秒50次打印会导致CPU占用飙升30%以上
|
||||
algo_log(f"数据写入成功:shape={data_np.shape}, 范围=[{data_np.min():.2f}, {data_np.max():.2f}] μV", level="DEBUG", record_once=True)
|
||||
|
||||
except Exception as e:
|
||||
algo_log(f"数据处理失败: {str(e)}", level="ERROR")
|
||||
if IniRead('system', 'algo_log_level', 'INFO') == 'DEBUG':
|
||||
algo_log(f"数据处理失败:{str(e)}", level="ERROR")
|
||||
# 调试阶段临时打开,生产环境务必注释
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
# -------------------------- 事件检测 --------------------------
|
||||
def detect_event(self, samples):
|
||||
self.pack_contain_event = False
|
||||
# 第65通道为事件通道
|
||||
events = np.array(samples[-2], dtype=np.int32).tolist()
|
||||
for idx, event in enumerate(events):
|
||||
if event in self.events:
|
||||
new_key = "".join(
|
||||
[
|
||||
str(event),
|
||||
datetime.datetime.now().strftime("%Y-%m-%d \
|
||||
-%H-%M-%S"),
|
||||
]
|
||||
)
|
||||
self.currentLabel = event
|
||||
if event == self.predict_event:
|
||||
self.count_events[new_key] = self.latency + 1
|
||||
def _process_send_queue(self):
|
||||
"""处理发送队列,向所有命令客户端广播消息"""
|
||||
while not self.send_queue.empty():
|
||||
method, params = self.send_queue.get()
|
||||
if self.cmd_clients:
|
||||
try:
|
||||
msg = {'method': method, 'params': params}
|
||||
msg_bytes = json.dumps(msg).encode('utf-8')
|
||||
|
||||
# 打印日志(隐藏大尺寸数据)
|
||||
if method in ['single_trial_plot', 'miReport']:
|
||||
print(f"{{'method': '{method}', 'params': <Base64 Image Data>}}")
|
||||
else:
|
||||
self.count_events[new_key] = self.train_latency + 1
|
||||
self.event_inner_idx = idx
|
||||
algo_log(f"事件检测到: {events},索引: {idx}", level="DEBUG")
|
||||
self.pack_contain_event = True
|
||||
print(f"Sending CMD message: {msg}")
|
||||
|
||||
# 倒计时并清理过期事件
|
||||
drop_items = []
|
||||
for key, value in self.count_events.items():
|
||||
value -= 1
|
||||
if value == 0:
|
||||
drop_items.append(key)
|
||||
self.count_events[key] = value
|
||||
# 广播到所有命令客户端
|
||||
for client_id in list(self.cmd_clients):
|
||||
try:
|
||||
self.cmd_socket.send_multipart([client_id, b'', msg_bytes])
|
||||
except Exception as e:
|
||||
print(f"Error sending to CMD client {client_id}: {e}")
|
||||
self.cmd_clients.discard(client_id) # 移除失效客户端
|
||||
except Exception as e:
|
||||
print(f"Error preparing broadcast: {e}")
|
||||
|
||||
for key in drop_items:
|
||||
del self.count_events[key]
|
||||
|
||||
if drop_items:
|
||||
return True
|
||||
return False
|
||||
# -------------------------- 主循环 --------------------------
|
||||
def run(self):
|
||||
self.running = True
|
||||
algo_log(f"ZMQ服务器启动成功 - host: {self.host}, 命令端口: {self.cmd_port}, 数据端口: {self.data_port}", level="INFO")
|
||||
print(f"ZMQ Server started - CMD Port: {self.cmd_port}, DATA Port: {self.data_port}")
|
||||
|
||||
try:
|
||||
while self.running:
|
||||
# 1. 处理两个端口的发送队列(必须在主线程执行)
|
||||
self._process_cmd_send_queue()
|
||||
self._process_data_send_queue()
|
||||
# 1. 处理发送队列(命令端口广播)
|
||||
self._process_send_queue()
|
||||
|
||||
# 2. 轮询监听两个端口的输入事件
|
||||
socks = dict(self.poller.poll(50))
|
||||
# 2. 轮训监听两个Socket的输入事件(10ms超时,避免阻塞)
|
||||
socks = dict(self.poller.poll(10))
|
||||
|
||||
# 处理8099命令端口消息
|
||||
# 处理命令端口消息
|
||||
if self.cmd_socket in socks and socks[self.cmd_socket] == zmq.POLLIN:
|
||||
frames = self.cmd_socket.recv_multipart()
|
||||
self._handle_cmd_message(frames)
|
||||
|
||||
# 处理8100数据端口消息(排空积压,消除标签延迟)
|
||||
# 处理数据端口消息
|
||||
if self.data_socket in socks and socks[self.data_socket] == zmq.POLLIN:
|
||||
while True:
|
||||
try:
|
||||
frames = self.data_socket.recv_multipart(zmq.NOBLOCK)
|
||||
frames = self.data_socket.recv_multipart()
|
||||
self._handle_data_message(frames)
|
||||
except zmq.Again:
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
algo_log(f"服务器主循环异常: {str(e)}", level="ERROR")
|
||||
return
|
||||
print(f"Server error occurred: {e}")
|
||||
finally:
|
||||
self.running = False
|
||||
# 优雅关闭所有资源
|
||||
# 关闭所有Socket和上下文
|
||||
self.cmd_socket.close()
|
||||
self.data_socket.close()
|
||||
self.context.term()
|
||||
algo_log("ZMQ服务器已关闭", level="INFO")
|
||||
print("Server sockets and context closed.")
|
||||
|
||||
def stop(self):
|
||||
"""显式关闭服务器"""
|
||||
@@ -451,10 +243,10 @@ class zmqServer(threading.Thread):
|
||||
self.cmd_socket.close()
|
||||
self.data_socket.close()
|
||||
self.context.term()
|
||||
algo_log(f"服务器已显式关闭 - 命令端口: {self.cmd_port}, 数据端口: {self.data_port}", level="INFO")
|
||||
print(f"Server closed explicitly - CMD Port: {self.cmd_port}, DATA Port: {self.data_port}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 初始化并启动服务器
|
||||
# 初始化并启动服务器(默认cmd=8099, data=8100)
|
||||
server = zmqServer()
|
||||
server.start()
|
||||
|
||||
@@ -463,5 +255,5 @@ if __name__ == '__main__':
|
||||
while server.running:
|
||||
threading.Event().wait(1)
|
||||
except KeyboardInterrupt:
|
||||
algo_log("收到键盘中断信号,正在停止服务器...", level="INFO")
|
||||
print("Received KeyboardInterrupt, stopping server...")
|
||||
server.stop()
|
||||
445
Zmq/zmqServer1.py
Normal file
445
Zmq/zmqServer1.py
Normal file
@@ -0,0 +1,445 @@
|
||||
import numpy as np
|
||||
import zmq
|
||||
import threading
|
||||
import json
|
||||
import queue
|
||||
import time
|
||||
from Device.SunnyLinker import SunnyLinker64, RingBuffer
|
||||
from collections import deque
|
||||
|
||||
|
||||
class zmqServer(threading.Thread):
|
||||
def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100):
|
||||
threading.Thread.__init__(self)
|
||||
self.host = host
|
||||
self.cmd_port = cmd_port
|
||||
self.data_port = data_port
|
||||
self.running = False
|
||||
self.get_Impedance = False
|
||||
self.open_Impedance = None
|
||||
self.StartDecode = False
|
||||
self.StartTrain = False
|
||||
self.state_mode = None
|
||||
self.currentLabel = -1
|
||||
self.IsExitApp = False
|
||||
self.getReport = False
|
||||
self.daemon = True
|
||||
|
||||
# ZMQ Context
|
||||
self.context = zmq.Context()
|
||||
|
||||
# 指令通道 (8099) - ROUTER
|
||||
self.cmd_socket = self.context.socket(zmq.ROUTER)
|
||||
self.cmd_socket.setsockopt(zmq.RCVHWM, 1000)
|
||||
self.cmd_socket.setsockopt(zmq.SNDHWM, 1000)
|
||||
self.cmd_socket.bind(f"tcp://{self.host}:{cmd_port}")
|
||||
|
||||
# 数据通道 (8100)) - ROUTER
|
||||
self.data_socket = self.context.socket(zmq.ROUTER)
|
||||
self.data_socket.setsockopt(zmq.RCVHWM, 1000)
|
||||
self.data_socket.setsockopt(zmq.RCVTIMEO, 50)
|
||||
self.data_socket.bind(f"tcp://{self.host}:{data_port}")
|
||||
|
||||
self.targetFreqs = []
|
||||
self.changeTarget = False
|
||||
self.sunnyLinker = SunnyLinker64(None, None, None, None, None)
|
||||
self.labels = [0x01, 0x02, 0x03]
|
||||
|
||||
self.decoder_switch = False
|
||||
self.decoder_class = None
|
||||
self.cmd_clients = set()
|
||||
self.data_clients = set()
|
||||
self.send_queue = queue.Queue()
|
||||
|
||||
# ========== 数据缓冲区 (RingBuffer) ==========
|
||||
# 与 SunnyLinker 保持一致,使用 RingBuffer
|
||||
# 66 = 64 EEG通道 + 1 事件通道(第65) + 1 标签序号通道(第66)
|
||||
# 缓存约 10 秒数据 (250Hz * 10s = 2500 点)
|
||||
self.n_chan = 66
|
||||
self.t_buffer = 10.0 # 缓冲区时长(秒)
|
||||
self.__ringBuffer = RingBuffer(self.n_chan, int(self.t_buffer * 250))
|
||||
|
||||
# 事件检测相关
|
||||
self._event_lock = threading.Lock()
|
||||
self._epoch_finished = False
|
||||
self._event_inner_idx = -1
|
||||
self.pack_contain_event = False
|
||||
self.predict_event = 99
|
||||
self.events = [1, 2, self.predict_event]
|
||||
self.count_events = {}
|
||||
self.latency = 50
|
||||
self.train_latency = 50
|
||||
|
||||
# 当前事件标签序号 (从第66通道获取)
|
||||
self.current_label_index = 0
|
||||
|
||||
# 初始化标志
|
||||
self._interval_inited = False
|
||||
self._currentLabel = -1
|
||||
|
||||
# 注册的客户端(兼容旧接口)
|
||||
self.clients = set()
|
||||
|
||||
# ========== 事件属性:线程安全访问 ==========
|
||||
@property
|
||||
def epoch_finished(self):
|
||||
with self._event_lock:
|
||||
return self._epoch_finished
|
||||
|
||||
@epoch_finished.setter
|
||||
def epoch_finished(self, value):
|
||||
with self._event_lock:
|
||||
self._epoch_finished = value
|
||||
|
||||
@property
|
||||
def event_inner_idx(self):
|
||||
with self._event_lock:
|
||||
return self._event_inner_idx
|
||||
|
||||
@event_inner_idx.setter
|
||||
def event_inner_idx(self, value):
|
||||
with self._event_lock:
|
||||
self._event_inner_idx = value
|
||||
|
||||
@property
|
||||
def interval_inited(self):
|
||||
return self._interval_inited
|
||||
|
||||
@interval_inited.setter
|
||||
def interval_inited(self, value):
|
||||
self._interval_inited = value
|
||||
|
||||
@property
|
||||
def currentLabel(self):
|
||||
return self._currentLabel
|
||||
|
||||
@currentLabel.setter
|
||||
def currentLabel(self, value):
|
||||
self._currentLabel = value
|
||||
|
||||
def broadcast_message(self, method, params):
|
||||
"""Put message into queue to be sent to all connected clients"""
|
||||
self.send_queue.put((method, params))
|
||||
|
||||
# ========== 数据缓冲区操作接口 ==========
|
||||
def GetDataLenCount(self):
|
||||
"""返回缓冲区当前数据点数"""
|
||||
return self.__ringBuffer.nUpdate
|
||||
|
||||
def getData(self, count):
|
||||
"""获取最新count个数据点,不消费(只读)"""
|
||||
with self.__ringBuffer.RingBufferLock:
|
||||
count = min(count, self.__ringBuffer.nUpdate)
|
||||
if count == 0:
|
||||
return np.zeros((self.n_chan, 0))
|
||||
|
||||
# 计算读取范围(从尾部取最新数据)
|
||||
read_end = (self.__ringBuffer.currentPtr - 1) % self.__ringBuffer.n_points
|
||||
read_start = (read_end - count + 1) % self.__ringBuffer.n_points
|
||||
|
||||
if self.__ringBuffer.currentPtr == 0:
|
||||
read_start = self.__ringBuffer.n_points - count
|
||||
read_end = self.__ringBuffer.n_points - 1
|
||||
|
||||
if read_start <= read_end:
|
||||
data = self.__ringBuffer.buffer[:, read_start:read_end + 1]
|
||||
else:
|
||||
part1 = self.__ringBuffer.buffer[:, read_start:]
|
||||
part2 = self.__ringBuffer.buffer[:, :read_end + 1]
|
||||
data = np.concatenate((part1, part2), axis=1)
|
||||
|
||||
return data
|
||||
|
||||
def consumeData(self, count):
|
||||
"""消费(丢弃)指定数量的数据点,从头部移除"""
|
||||
with self.__ringBuffer.RingBufferLock:
|
||||
count = min(count, self.__ringBuffer.nUpdate)
|
||||
self.__ringBuffer.readPtr = (self.__ringBuffer.readPtr + count) % self.__ringBuffer.n_points
|
||||
self.__ringBuffer.nUpdate -= count
|
||||
|
||||
def ResetAll(self):
|
||||
"""重置缓冲区"""
|
||||
with self.__ringBuffer.RingBufferLock:
|
||||
self.__ringBuffer.resetAllPara()
|
||||
with self._event_lock:
|
||||
self._epoch_finished = False
|
||||
self._event_inner_idx = -1
|
||||
self.pack_contain_event = False
|
||||
self.count_events.clear()
|
||||
self.current_label_index = 0
|
||||
|
||||
def reset_data_buffer(self):
|
||||
self.ResetAll()
|
||||
|
||||
def reset_state(self):
|
||||
self.ResetAll()
|
||||
|
||||
def interval_init(self, decoder_class):
|
||||
"""初始化事件检测参数"""
|
||||
import ast
|
||||
from PubLibrary.InifileHelper import IniRead
|
||||
|
||||
if decoder_class == 'ssmvep':
|
||||
interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
|
||||
self.interval_epoch = [int(i * 250) for i in interval_epoch]
|
||||
self.train_epoch = [int(self.interval_epoch[0]),
|
||||
int(self.interval_epoch[1] + 0.1 * 250)]
|
||||
self.latency = (self.interval_epoch[1] + 0.1 * 250) // 5
|
||||
self.train_latency = (self.train_epoch[1] + 0.1 * 250) // 5
|
||||
|
||||
elif decoder_class == 'mi':
|
||||
interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
|
||||
self.interval_epoch = [int(i * 250) for i in interval_epoch]
|
||||
self.train_epoch = self.interval_epoch.copy()
|
||||
self.latency = self.interval_epoch[1] // 5
|
||||
self.train_latency = self.latency
|
||||
|
||||
self.count_events = {}
|
||||
self._event_inner_idx = -1
|
||||
self._epoch_finished = False
|
||||
self.pack_contain_event = False
|
||||
self.predict_event = 99
|
||||
self.events = [1, 2, self.predict_event]
|
||||
self._interval_inited = True
|
||||
|
||||
# ========== 事件检测 ==========
|
||||
def detect_event(self, data_matrix):
|
||||
"""
|
||||
检测事件通道中的触发信号
|
||||
|
||||
@param data_matrix: shape (66, N) - N个采样点的数据
|
||||
第65行(索引64) = 事件通道
|
||||
第66行(索引65) = 标签通道
|
||||
@return: 是否检测到事件
|
||||
"""
|
||||
if data_matrix.shape[1] == 0:
|
||||
return False
|
||||
|
||||
self.pack_contain_event = False
|
||||
event_channel = data_matrix[64, :] # 第65通道 = 标签值(event值)
|
||||
label_channel = data_matrix[65, :] # 第66通道 = 标签序号(label index)
|
||||
|
||||
events = event_channel.tolist()
|
||||
|
||||
with self._event_lock:
|
||||
self._event_inner_idx = -1
|
||||
self.current_event_label = 0
|
||||
|
||||
for idx, event in enumerate(events):
|
||||
if int(event) in self.events:
|
||||
self._event_inner_idx = idx
|
||||
self.current_label_index = int(label_channel[idx])
|
||||
self.pack_contain_event = True
|
||||
|
||||
new_key = f"{event}_{time.time()}"
|
||||
latency = self.latency if event == self.predict_event else self.train_latency
|
||||
self.count_events[new_key] = latency + 1
|
||||
|
||||
# 延迟计数递减
|
||||
drop_items = []
|
||||
for key, value in self.count_events.items():
|
||||
value = value - 1
|
||||
if value == 0:
|
||||
drop_items.append(key)
|
||||
self.count_events[key] = value
|
||||
for key in drop_items:
|
||||
del self.count_events[key]
|
||||
|
||||
if drop_items:
|
||||
self._epoch_finished = True
|
||||
# 检测到事件时,清除RingBuffer中之前的数据,只保留当前包
|
||||
if self.pack_contain_event:
|
||||
self.__ringBuffer.resetAllPara()
|
||||
return True
|
||||
|
||||
self._epoch_finished = False
|
||||
return False
|
||||
|
||||
def run(self):
|
||||
self.running = True
|
||||
print(f"Server running - CMD: {self.cmd_port}, DATA: {self.data_port}")
|
||||
|
||||
cmd_poller = zmq.Poller()
|
||||
cmd_poller.register(self.cmd_socket, zmq.POLLIN)
|
||||
|
||||
data_poller = zmq.Poller()
|
||||
data_poller.register(self.data_socket, zmq.POLLIN)
|
||||
|
||||
try:
|
||||
while self.running:
|
||||
# --- 处理发送队列 (指令通道) ---
|
||||
while not self.send_queue.empty():
|
||||
method, params = self.send_queue.get()
|
||||
if self.cmd_clients:
|
||||
try:
|
||||
msg = {'method': method, 'params': params}
|
||||
msg_bytes = json.dumps(msg).encode('utf-8')
|
||||
for client_id in list(self.cmd_clients):
|
||||
try:
|
||||
self.cmd_socket.send_multipart([client_id, b'', msg_bytes])
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# --- 处理指令通道 ---
|
||||
socks = dict(cmd_poller.poll(10))
|
||||
if self.cmd_socket in socks:
|
||||
self._handle_cmd_socket()
|
||||
|
||||
# --- 处理数据通道 ---
|
||||
socks = dict(data_poller.poll(10))
|
||||
if self.data_socket in socks:
|
||||
self._handle_data_socket()
|
||||
|
||||
except Exception as e:
|
||||
print(f"Server error: {e}")
|
||||
finally:
|
||||
self.running = False
|
||||
self.cmd_socket.close()
|
||||
self.data_socket.close()
|
||||
self.context.term()
|
||||
|
||||
def _handle_cmd_socket(self):
|
||||
"""处理指令通道消息"""
|
||||
try:
|
||||
frames = self.cmd_socket.recv_multipart()
|
||||
if len(frames) < 3:
|
||||
return
|
||||
ident, _, message_bytes = frames[:3]
|
||||
self.cmd_clients.add(ident)
|
||||
self.clients.add(ident)
|
||||
|
||||
message = json.loads(message_bytes.decode('utf-8'))
|
||||
method = message.get("method")
|
||||
params = message.get("params")
|
||||
|
||||
print(f"[CMD] {method}: {params}")
|
||||
|
||||
if method == "sync":
|
||||
self.state_mode = 'sync'
|
||||
elif method == "targetFreqs":
|
||||
if isinstance(params, list) and params != self.targetFreqs:
|
||||
self.targetFreqs = params
|
||||
self.changeTarget = True
|
||||
elif method == "decoderClass":
|
||||
if isinstance(params, str) and params != self.decoder_class:
|
||||
self.decoder_class = params
|
||||
self.decoder_switch = True
|
||||
elif method == "getReport":
|
||||
self.getReport = True
|
||||
elif method == "train":
|
||||
self.state_mode = 'train'
|
||||
self.StartTrain = True
|
||||
self.currentLabel = params
|
||||
elif method == "predict":
|
||||
self.state_mode = 'predict'
|
||||
if params == 1:
|
||||
self.StartDecode = True
|
||||
elif params == 2:
|
||||
self.IsExitApp = True
|
||||
self.running = False
|
||||
elif method == "rest":
|
||||
self.state_mode = 'rest'
|
||||
elif method == "impedance":
|
||||
if params == 1:
|
||||
self.open_Impedance = True
|
||||
self.get_Impedance = True
|
||||
elif params == 2:
|
||||
self.open_Impedance = False
|
||||
self.get_Impedance = False
|
||||
|
||||
except Exception as e:
|
||||
print(f"CMD socket error: {e}")
|
||||
|
||||
def _handle_data_socket(self):
|
||||
"""处理数据通道消息 (EEG数据)
|
||||
|
||||
上位机数据格式:
|
||||
- 数据帧: [identity, '', meta_json, data_buffer]
|
||||
data_buffer = [N, 66] float32 -> 转置为 [66, N]
|
||||
"""
|
||||
try:
|
||||
frames = self.data_socket.recv_multipart()
|
||||
if len(frames) < 4:
|
||||
return
|
||||
ident, _, message_bytes = frames[:3]
|
||||
self.data_clients.add(ident)
|
||||
|
||||
meta = json.loads(message_bytes.decode('utf-8'))
|
||||
|
||||
# data: [N, 66] -> 转置 -> [66, N]
|
||||
raw_data = np.frombuffer(frames[3], dtype=np.float32)
|
||||
n_samples, n_channels = meta.get('shape', [5, 66])
|
||||
data_matrix = raw_data.reshape(n_samples, n_channels).T.astype(np.float32)
|
||||
|
||||
# 写入 RingBuffer
|
||||
with self.__ringBuffer.RingBufferLock:
|
||||
self.__ringBuffer.appendBuffer(data_matrix)
|
||||
|
||||
# 事件检测
|
||||
self.detect_event(data_matrix)
|
||||
|
||||
except Exception as e:
|
||||
print(f"DATA socket error: {e}")
|
||||
|
||||
# ========== 各范式数据访问接口 ==========
|
||||
def get_MIData(self):
|
||||
"""获取MI导联数据 (21通道 + 事件)"""
|
||||
data = self.getData(self.GetDataLenCount())
|
||||
rows_to_extract = [8, 15, 12, 14, 18, 23, 16, 59, 50, 58, 17, 45, 29, 11, 10, 19, 20, 61, 51, 60, 21, 64, 65]
|
||||
row_to_select = np.array(rows_to_extract)
|
||||
if data.shape[1] > 0:
|
||||
return data[row_to_select, :]
|
||||
return np.zeros((len(rows_to_extract), 0))
|
||||
|
||||
def get_SSMVEPData(self):
|
||||
"""获取SSMVEP导联数据 (8通道 + 事件)"""
|
||||
data = self.getData(self.GetDataLenCount())
|
||||
rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64, 65]
|
||||
row_to_select = np.array(rows_to_extract)
|
||||
if data.shape[1] > 0:
|
||||
return data[row_to_select, :]
|
||||
return np.zeros((len(rows_to_extract), 0))
|
||||
|
||||
def getDataViaSSVEP(self, count):
|
||||
"""获取SSVEP数据 (8通道 + 事件)"""
|
||||
data = self.getData(count)
|
||||
rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64]
|
||||
row_to_select = np.array(rows_to_extract)
|
||||
if data.shape[1] > 0:
|
||||
return data[row_to_select, :]
|
||||
return np.zeros((len(rows_to_extract), 0))
|
||||
|
||||
def get_concentrateData(self, count):
|
||||
"""获取专注力数据 (2通道)"""
|
||||
data = self.getData(count)
|
||||
rows_to_extract = [0, 1]
|
||||
row_to_select = np.array(rows_to_extract)
|
||||
if data.shape[1] > 0:
|
||||
return data[row_to_select, :]
|
||||
return np.zeros((len(rows_to_extract), 0))
|
||||
|
||||
def get_blinkData(self, count):
|
||||
"""获取眨眼数据 (2通道)"""
|
||||
data = self.getData(count)
|
||||
rows_to_extract = [0, 1]
|
||||
row_to_select = np.array(rows_to_extract)
|
||||
if data.shape[1] > 0:
|
||||
return data[row_to_select, :]
|
||||
return np.zeros((len(rows_to_extract), 0))
|
||||
|
||||
def getImpedance(self, data, decoder_class):
|
||||
"""计算阻抗(ZMQ模式下不可用)"""
|
||||
return np.zeros(8)
|
||||
|
||||
def stop(self):
|
||||
self.running = False
|
||||
self.cmd_socket.close()
|
||||
self.data_socket.close()
|
||||
self.context.term()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
server = zmqServer()
|
||||
server.start()
|
||||
@@ -8,7 +8,6 @@ import os
|
||||
# import logging
|
||||
import base64
|
||||
import io
|
||||
import math
|
||||
|
||||
# logger = logging.getLogger(__name__)
|
||||
#
|
||||
@@ -23,7 +22,7 @@ import math
|
||||
|
||||
|
||||
class Calculate():
|
||||
def __init__(self, Threshold_value_low, Threshold_value_high, fs=250, win_len=10, config=None):
|
||||
def __init__(self, Threshold_value_low, Threshold_value_high, fs=250, win_len=10):
|
||||
self.Threshold_value_low = Threshold_value_low
|
||||
self.Threshold_value_high = Threshold_value_high
|
||||
self.fs = fs
|
||||
@@ -32,73 +31,47 @@ class Calculate():
|
||||
self.EVI_result = []
|
||||
self.eegQueue = deque(maxlen=win_len)
|
||||
|
||||
# # 存储历史数据用于绘图
|
||||
# self.beta_history = []
|
||||
# self.alpha_history = []
|
||||
# self.theta_history = []
|
||||
# self.focus_history = []
|
||||
# self.timestamp_history = []
|
||||
#
|
||||
# # 记录开始时间
|
||||
# self.start_time = None
|
||||
# self.recording = False
|
||||
#
|
||||
# # 图表保存路径
|
||||
# self.chart_dir = "reports"
|
||||
# if not os.path.exists(self.chart_dir):
|
||||
# os.makedirs(self.chart_dir)
|
||||
# print(f"[调试] 创建目录: {self.chart_dir}")
|
||||
|
||||
# 初始化滤波器
|
||||
self.b_notch, self.a_notch = signal.iirnotch(50 / (self.fs/2), 30)
|
||||
self.b_design = signal.firwin(65, [2 / (self.fs/2), 40 / (self.fs/2)], pass_zero=False)
|
||||
|
||||
self.last_focus = None
|
||||
# 异步滤波系数配置(核心手感控制纽)
|
||||
self.alpha_up = 1 # 上升系数:较小,保证分数平滑爬升,过滤偶发的瞬时高能量
|
||||
# alpha_down / shrink_factor 从 config.ini 读取,方便上位机调参
|
||||
if config:
|
||||
self.alpha_down = float(config.get('alpha_down', 0.8))
|
||||
self.shrink_factor = float(config.get('shrink_factor', 0.5))
|
||||
else:
|
||||
self.alpha_down = 0.8
|
||||
self.shrink_factor = 0.5
|
||||
print("[调试] Calculate 类初始化完成")
|
||||
|
||||
def calculate_focus(self, beta, alpha, theta):
|
||||
"""
|
||||
专注度计算 - 三区间门限异步滤波版本
|
||||
专注度计算 - 固定映射版本
|
||||
"""
|
||||
# 0. 频带特征预处理
|
||||
theta_mod = theta ** 0.7
|
||||
|
||||
# 原始比值
|
||||
raw = beta / (alpha + theta_mod + 1e-10)
|
||||
raw = beta / (alpha + theta + 1e-10)
|
||||
|
||||
exponent = 2.0
|
||||
# Sigmoid 映射:让 raw 在 0.3-1.5 区间敏感
|
||||
# 参数可调:
|
||||
# k = 12 (斜率,越大越陡)
|
||||
# x0 = 0.6 (中心点,raw=0.6时focus≈50)
|
||||
k = 12.0
|
||||
x0 = 0.6
|
||||
focus = 100.0 / (1.0 + np.exp(-k * (raw - x0)))
|
||||
|
||||
# 1. 防止脑电比值出现负数异常值
|
||||
raw_input = max(raw, 0.0)
|
||||
|
||||
# 2. 2次幂纵轴压缩映射 (shrink_factor 从 config.ini 读取)
|
||||
focus_raw = 100 * self.shrink_factor * (raw_input ** exponent)
|
||||
|
||||
# 3. 计算当前帧的瞬时分数 (基准量级 0-120)
|
||||
instant_focus = 120 * (1.0 - np.exp(-focus_raw / 100.0))
|
||||
|
||||
# 4. 核心修改:三区间门限时域滤波
|
||||
if self.last_focus is None:
|
||||
# 冷启动:首帧直接赋值
|
||||
focus = instant_focus
|
||||
else:
|
||||
# 判断当前瞬时分数是否处于【极端区】(80以上 或 60以下)
|
||||
if instant_focus > 85.0 or instant_focus < 60.0:
|
||||
# 执行异步低通时域滤波
|
||||
if instant_focus >= self.last_focus:
|
||||
# 趋势上升:慢爬升
|
||||
focus = self.alpha_up * instant_focus + (1 - self.alpha_up) * self.last_focus
|
||||
else:
|
||||
# 趋势下降:快跌落
|
||||
focus = self.alpha_down * instant_focus + (1 - self.alpha_down) * self.last_focus
|
||||
else:
|
||||
# 【高灵敏自由区】(60 <= instant_focus <= 80)
|
||||
# 不执行异步滤波,分数直接跟随瞬时值,保证中间状态绝对跟手
|
||||
focus = instant_focus
|
||||
|
||||
# 5. 更新历史状态缓存
|
||||
self.last_focus = focus
|
||||
|
||||
# 打印在线调试日志,方便观察区间切换
|
||||
zone_tag = "极端区(滤波)" if (instant_focus > 80 or instant_focus < 60) else "自由区(直通)"
|
||||
print(f"原始特征比值 raw: {raw:.4f} | 瞬时分数: {instant_focus:.1f} | 滤波后分数: {focus:.1f}")
|
||||
|
||||
# 最终返回整型
|
||||
# 可选:添加滑动平均平滑
|
||||
return int(focus)
|
||||
|
||||
|
||||
def calculate_all(self, data, fs, nperseg=1000):
|
||||
mean_x = np.mean(data, axis=-1, keepdims=True)
|
||||
data = data - mean_x
|
||||
@@ -346,16 +319,14 @@ class Calculate():
|
||||
if eegData.size == 0:
|
||||
return None
|
||||
eegData -= np.mean(eegData, axis=-1, keepdims=True)
|
||||
# eegData = signal.lfilter(self.b_notch, self.a_notch, eegData) # 陷波
|
||||
# eegData = signal.lfilter(self.b_design, 1, eegData) # 滤波
|
||||
focus_score, CLI_score, beta_psd, alpha_psd, theta_psd = self.calculate_all(eegData, fs=self.fs, nperseg=1000)
|
||||
eegData = signal.lfilter(self.b_notch, self.a_notch, eegData)
|
||||
eegData = signal.lfilter(self.b_design, 1, eegData)
|
||||
focus_score, CLI_score, beta, alpha, theta = self.calculate_all(eegData, fs=self.fs, nperseg=1000)
|
||||
|
||||
# self.add_data_point(focus_score, beta_psd, alpha_psd, theta_psd) # 已注释(方法已移除)
|
||||
|
||||
# return (focus_score)
|
||||
return (focus_score, beta_psd)
|
||||
# return None
|
||||
# self.add_data_point(focus_score, beta, alpha, theta)
|
||||
|
||||
return focus_score
|
||||
return None
|
||||
|
||||
|
||||
class Calculate2():
|
||||
|
||||
19
config.ini
19
config.ini
@@ -15,19 +15,14 @@ Audio_device = 0
|
||||
Rest_time = 2
|
||||
Upper_Host = 127.0.0.1
|
||||
Upper_Port = 8088
|
||||
Decoder_Host = 127.0.0.1
|
||||
Decoder_Port = 8099
|
||||
Serial_port = COM44
|
||||
algo_log_path = d:/Program Files/64chn_Decoder/logs
|
||||
algo_log_level = DEBUG
|
||||
console_output = 1
|
||||
save_train_data = 0
|
||||
zmqServer_host = 10.200.27.140
|
||||
|
||||
; 64 导设备配置
|
||||
[device_type_1]
|
||||
sample_rate = 250
|
||||
frame_points = 5
|
||||
channel_nums = 66
|
||||
channel_names = ['FP1', 'FP2', 'PO6', 'POZ', 'F3', 'F4', 'FPZ', 'AF4', 'FC3', 'PO8', 'CP2', 'CP1', 'FCZ', 'PO5', 'FC2', 'FC1', 'C3', 'C4', 'FC4', 'CP4', 'P3', 'P4', 'F5', 'C5', 'F6', 'PO4', 'CP6', 'CP5', 'PO3', 'CP3', 'FC6', 'FC5', 'CB1', 'CB2', 'P5', 'AF7', 'A1', 'T7', 'FT7', 'TP7', 'FT8', 'AF8', 'F8', 'F7', 'P6', 'C6', 'O2', 'O1', 'T8', 'P7', 'CZ', 'PZ', 'P8', 'FZ', 'OZ', 'PO7', 'TP8', 'AF3', 'C2', 'C1', 'P2', 'P1', 'F2', 'F1', 'label', 'label_tag']
|
||||
channel_index = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65]
|
||||
|
||||
; 64 导设备配置 1; 32 2; 24 3; 16 4; 8 5; 4 6;
|
||||
[device_type] = 1
|
||||
device_sample_rate = 250
|
||||
device_channel_nums = 66
|
||||
device_channel_names = ['FP1', 'FP2', 'PO6', 'POZ', 'F3', 'F4', 'FPZ', 'AF4', 'FC3', 'PO8', 'CP2', 'CP1', 'FCZ', 'PO5', 'FC2', 'FC1', 'C3', 'C4', 'FC4', 'CP4', 'P3', 'P4', 'F5', 'C5', 'F6', 'PO4', 'CP6', 'CP5', 'PO3', 'CP3', 'FC6', 'FC5', 'CB1', 'CB2', 'P5', 'AF7', 'A1', 'T7', 'FT7', 'TP7', 'FT8', 'AF8', 'F8', 'F7', 'P6', 'C6', 'O2', 'O1', 'T8', 'P7', 'CZ', 'PZ', 'P8', 'FZ', 'OZ', 'PO7', 'TP8', 'AF3', 'C2', 'C1', 'P2', 'P1', 'F2', 'F1', 'label', 'label_tag']
|
||||
device_channel_index = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65]
|
||||
|
||||
188
datamock.py
188
datamock.py
@@ -1,188 +0,0 @@
|
||||
import zmq
|
||||
import numpy as np
|
||||
import time
|
||||
import threading
|
||||
from datetime import datetime
|
||||
|
||||
# ========== 参数配置 ==========
|
||||
FS = 250 # 采样率 Hz
|
||||
N_SAMPLES_PER_PKT = 5 # 每包采样点数
|
||||
N_CHAN = 66 # 通道数: 64 EEG + 1 标签值 + 1 标签序号
|
||||
EEG_FREQ = 10 # EEG 正弦波频率 Hz
|
||||
EEG_AMP = 100.0 # EEG 幅值 100μV
|
||||
LABEL_INTERVAL = 5 # 标签间隔秒数
|
||||
SERVER_ADDR = 'tcp://127.0.0.1:8100'
|
||||
LABEL_CMD_ADDR = 'tcp://127.0.0.1:8101' # 接收来自上位机范式的标签命令
|
||||
|
||||
# 发送间隔: 每包 5 采样点 / 250Hz = 20ms
|
||||
PKT_INTERVAL = N_SAMPLES_PER_PKT / FS
|
||||
|
||||
|
||||
def build_packet(global_sample_idx):
|
||||
"""
|
||||
生成一包 [5, 66] 的 float64 数据
|
||||
:param global_sample_idx: 当前包第一个采样点在全局序列中的索引 (从 0 开始)
|
||||
:return: np.ndarray shape [5, 66]
|
||||
"""
|
||||
# 当前包内 5 个采样点对应的时间(秒)
|
||||
t = (global_sample_idx + np.arange(N_SAMPLES_PER_PKT)) / FS
|
||||
|
||||
# Ch0-63: EEG 10Hz 正弦波,幅值 100μV
|
||||
# t shape [5,],sin 乘以标量后仍是 [5,],需要 reshape 为 [5,1] 再广播到 64 通道
|
||||
eeg = (EEG_AMP * np.sin(2 * np.pi * EEG_FREQ * t)).reshape(N_SAMPLES_PER_PKT, 1) # [5, 1]
|
||||
eeg = np.tile(eeg, (1, 64)) # [5, 64]
|
||||
|
||||
# Ch64: 标签值通道,初始化为 0
|
||||
event = np.zeros((N_SAMPLES_PER_PKT, 1), dtype=np.float64)
|
||||
|
||||
# Ch65: 标签序号通道,初始化为 0
|
||||
label_idx = np.zeros((N_SAMPLES_PER_PKT, 1), dtype=np.float64)
|
||||
|
||||
# 拼成 [5, 66]
|
||||
packet = np.concatenate([eeg, event, label_idx], axis=1).astype(np.float64)
|
||||
return packet
|
||||
|
||||
|
||||
def should_send_label(global_sample_idx):
|
||||
"""
|
||||
判断当前包是否包含标签触发点(每 5s 的最后一个采样点)
|
||||
采样点索引从 0 开始,每 5s = 1250 个采样点
|
||||
最后一个采样点索引: 1249, 2499, 3749, ...
|
||||
由于每包 5 个采样点,标签点落在包内的最后一个采样点位置
|
||||
即当前包起始索引 global_sample_idx 必须使得:
|
||||
global_sample_idx <= 标签点索引 < global_sample_idx + N_SAMPLES_PER_PKT
|
||||
也就是 global_sample_idx <= 1249 < global_sample_idx + 5
|
||||
即 global_sample_idx = 1245, 2495, 3745, ...
|
||||
即 global_sample_idx = n * LABEL_INTERVAL * FS - N_SAMPLES_PER_PKT
|
||||
"""
|
||||
samples_per_interval = LABEL_INTERVAL * FS
|
||||
# 检查当前包是否包含 interval 的最后一个采样点
|
||||
# 标签点索引 = n * 1250 - 1,当 global_sample_idx = n*1250-5 时,标签在包内索引 4
|
||||
return (global_sample_idx + N_SAMPLES_PER_PKT - 1) % samples_per_interval == samples_per_interval - 1
|
||||
|
||||
|
||||
def main():
|
||||
ctx = zmq.Context()
|
||||
sock = ctx.socket(zmq.DEALER)
|
||||
sock.connect(SERVER_ADDR)
|
||||
print(f"[{datetime.now().strftime('%H:%M:%S')}] ZMQ Dealer 连接到 {SERVER_ADDR}")
|
||||
|
||||
# ========== 上位机标签命令监听 ==========
|
||||
# 使用线程安全的队列接收来自 ssmvep_main.py 的标签命令
|
||||
# 标签值: 1 (train 0), 2 (train 1), 99 (predict)
|
||||
pending_label = [None] # [label_value or None]
|
||||
label_lock = threading.Lock()
|
||||
|
||||
label_cmd_sock = ctx.socket(zmq.PULL)
|
||||
label_cmd_sock.bind(LABEL_CMD_ADDR)
|
||||
print(f"[{datetime.now().strftime('%H:%M:%S')}] 标签命令监听绑定到 {LABEL_CMD_ADDR}")
|
||||
|
||||
stop_recv = threading.Event()
|
||||
|
||||
def label_cmd_thread():
|
||||
"""监听来自上位机范式的标签命令,写入 pending_label"""
|
||||
while not stop_recv.is_set():
|
||||
try:
|
||||
msg = label_cmd_sock.recv_string(zmq.NOBLOCK)
|
||||
label_val = int(msg)
|
||||
with label_lock:
|
||||
pending_label[0] = label_val
|
||||
ts = datetime.now().strftime('%H:%M:%S')
|
||||
label_name = {1: 'train_0', 2: 'train_1', 99: 'predict'}.get(label_val, str(label_val))
|
||||
print(f"[{ts}] 收到标签命令: {label_name} -> label={label_val}")
|
||||
except zmq.Again:
|
||||
time.sleep(0.005)
|
||||
except Exception as e:
|
||||
print(f"[label_cmd_thread] 错误: {e}")
|
||||
time.sleep(0.01)
|
||||
|
||||
label_thread = threading.Thread(target=label_cmd_thread, daemon=True)
|
||||
label_thread.start()
|
||||
print(f"[{datetime.now().strftime('%H:%M:%S')}] 标签命令监听线程已启动")
|
||||
|
||||
# 后台消费线程:持续 recv 从 ROUTER 返回的数据,避免 server 发送队列积压
|
||||
recv_count = [0]
|
||||
|
||||
def consumer_thread():
|
||||
"""消费线程:阻塞 recv,丢弃收到的数据,仅用于清空 ROUTER 发送队列"""
|
||||
while not stop_recv.is_set():
|
||||
try:
|
||||
frames = sock.recv_multipart(zmq.NOBLOCK)
|
||||
recv_count[0] += 1
|
||||
# 收到的格式: [identity, '', filtered_data_bytes]
|
||||
if recv_count[0] % 500 == 0:
|
||||
print(f"[{datetime.now().strftime('%H:%M:%S')}] 消费线程已丢弃 {recv_count[0]} 帧滤波数据")
|
||||
except zmq.Again:
|
||||
time.sleep(0.01)
|
||||
except zmq.error.Again: # 兼容旧版
|
||||
time.sleep(0.01)
|
||||
|
||||
consumer = threading.Thread(target=consumer_thread, daemon=True)
|
||||
consumer.start()
|
||||
print(f"[{datetime.now().strftime('%H:%M:%S')}] 消费线程已启动(daemon)")
|
||||
|
||||
global_sample_idx = 0 # 全局采样点计数器
|
||||
label_type = 1 # 当前标签类型: 1 或 2
|
||||
label1_count = 0 # label=1 的序号计数器
|
||||
label2_count = 0 # label=2 的序号计数器
|
||||
packet_count = 0 # 已发送包数
|
||||
|
||||
print(f"[{datetime.now().strftime('%H:%M:%S')}] 开始发送模拟数据 ...")
|
||||
print(f" 采样率: {FS}Hz | 每包 {N_SAMPLES_PER_PKT} 采样点 | 发送间隔 {PKT_INTERVAL*1000:.0f}ms")
|
||||
print(f" EEG: {EEG_FREQ}Hz 正弦波 | 幅值 {EEG_AMP}μV")
|
||||
print(f" 标签: 来自上位机范式命令 (train_0=1, train_1=2, predict=99)")
|
||||
print("-" * 50)
|
||||
|
||||
try:
|
||||
while True:
|
||||
t_start = time.perf_counter()
|
||||
|
||||
# 构建当前包
|
||||
packet = build_packet(global_sample_idx)
|
||||
|
||||
# 检查是否有来自上位机范式的挂起标签命令
|
||||
with label_lock:
|
||||
ext_label = pending_label[0]
|
||||
if ext_label is not None:
|
||||
pending_label[0] = None
|
||||
|
||||
if ext_label is not None:
|
||||
# 将标签写入当前包所有5个采样点的第65通道 (index 64)
|
||||
# 覆盖全部采样点确保 event_inner_idx 无论落在哪个位置都能被正确检测
|
||||
packet[:, 64] = float(ext_label)
|
||||
ts = datetime.now().strftime('%H:%M:%S')
|
||||
print(f"[{ts}] 打标签: label={ext_label} -> ch64[all 5 samples] (global_sample_idx={global_sample_idx})")
|
||||
|
||||
# 发送: multipart 2帧 ['', data]
|
||||
# 使用标准格式,ROUTER 会自动附加 ZMQ 分配的客户端身份
|
||||
sock.send_multipart([
|
||||
b'',
|
||||
packet.tobytes()
|
||||
])
|
||||
|
||||
# 每 50 包打印一次进度
|
||||
if packet_count % 50 == 0:
|
||||
ts = datetime.now().strftime('%H:%M:%S')
|
||||
print(f"[{ts}] 已发送 {packet_count} 包 (global_sample_idx={global_sample_idx})")
|
||||
|
||||
global_sample_idx += N_SAMPLES_PER_PKT
|
||||
packet_count += 1
|
||||
|
||||
# 精确控制发送节奏: 等待到 PKT_INTERVAL 秒
|
||||
elapsed = time.perf_counter() - t_start
|
||||
sleep_time = PKT_INTERVAL - elapsed
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print(f"\n[{datetime.now().strftime('%H:%M:%S')}] 停止发送,共发送 {packet_count} 包")
|
||||
finally:
|
||||
stop_recv.set()
|
||||
consumer.join(timeout=2)
|
||||
label_cmd_sock.close()
|
||||
sock.close()
|
||||
ctx.term()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
421
filter_test.py
421
filter_test.py
@@ -1,421 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
脑电滤波服务 8100端口测试工具【统计逻辑专项优化版】
|
||||
优化点:
|
||||
1. 5秒预热(250个发包),预热结束后才启动丢包/数据统计
|
||||
2. 业务比例:0.02s发1包,200ms收1包 → 每 10 个发包对应 1 个回包
|
||||
3. 通道校验:发送(5,66) 仅对比前64通道,接收(50,64)全通道比对
|
||||
4. 区分:全局总包数 / 有效统计区间包数、理论收包数、实际收包数、丢包数、丢包率
|
||||
5. 新增64通道整体数据均值/极值比对,校验数据有效性
|
||||
通信规范:send_multipart([client_id, b"", data_buf]) 三帧报文,服务端 recv_multipart 长度=3
|
||||
"""
|
||||
import sys
|
||||
import time
|
||||
import threading
|
||||
import logging
|
||||
import traceback
|
||||
from collections import deque
|
||||
import numpy as np
|
||||
import zmq
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.animation import FuncAnimation
|
||||
|
||||
# ===================== 全局前置:修复Matplotlib中文字体 & 负号显示 =====================
|
||||
plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "WenQuanYi Micro Hei"]
|
||||
plt.rcParams["axes.unicode_minus"] = False
|
||||
|
||||
# ===================== 【1. 全局业务固定参数(核心统计规则)】 =====================
|
||||
# ZMQ 服务端配置
|
||||
ZMQ_SERVER_IP = "127.0.0.1"
|
||||
ZMQ_SERVER_PORT = 8100
|
||||
ZMQ_SOCKET_TIMEOUT = 3000 # 套接字超时(ms)
|
||||
POLL_TIMEOUT = 10 # Poll轮询超时(ms)
|
||||
|
||||
# 时序 & 统计核心规则(严格对齐现场业务)
|
||||
SEND_INTERVAL = 0.02 # 上位机发包间隔:20ms/包
|
||||
RECV_INTERVAL = 0.2 # 服务端回包间隔:200ms/包
|
||||
PREHEAT_SECONDS = 5.0 # 滤波缓存预热时长:5秒
|
||||
# 计算:预热需要的发包总数 = 预热时长 / 单包发送间隔
|
||||
PREHEAT_SEND_PACKS = int(PREHEAT_SECONDS / SEND_INTERVAL) # 5 / 0.02 = 250 包
|
||||
# 收发包比例:每多少个发包对应1个回包
|
||||
PACK_RATIO = int(RECV_INTERVAL / SEND_INTERVAL) # 0.2 / 0.02 = 10
|
||||
|
||||
# 数据报文形状
|
||||
PKG_SEND_SHAPE = (5, 66) # 发送包 (点数, 总通道)
|
||||
PKG_RECV_SHAPE = (50, 64) # 回包 (点数, 有效脑电通道)
|
||||
SAMPLE_RATE = 250
|
||||
|
||||
# 通道定义(对比仅使用前64路脑电通道)
|
||||
CH_EEG_VALID = 64 # 共同对比通道数:0~63
|
||||
CH_EVENT = 64
|
||||
CH_RESERVED = 65
|
||||
|
||||
# ZMQ 三帧报文固定字段
|
||||
CLIENT_ID = b"test_client_001"
|
||||
EMPTY_FRAME = b""
|
||||
|
||||
# 仿真信号配置
|
||||
TARGET_CHANNEL = 0
|
||||
SIGNAL_FREQ_LIST = [13]
|
||||
SIGNAL_AMP = 1.8
|
||||
NOISE_GAUSSIAN_AMP = 0.4
|
||||
NOISE_POWER50_AMP = 0.3
|
||||
EVENT_LABEL_VAL = 1
|
||||
RESERVED_VAL = 0.0
|
||||
|
||||
# 可视化配置
|
||||
MAX_PLOT_POINTS = 800
|
||||
PLOT_REFRESH_INTERVAL = 80
|
||||
FFT_N_POINTS = 256
|
||||
PLOT_X_LIMIT_FREQ = (0, 60)
|
||||
|
||||
# 运行控制
|
||||
MAX_RUN_SECONDS = None
|
||||
ENABLE_RECONNECT = True
|
||||
PRINT_STAT_INTERVAL = 5.0
|
||||
|
||||
# ===================== 【2. 全局变量 + 统计结构体(重构统计逻辑)】 =====================
|
||||
g_running = threading.Event()
|
||||
g_running.set()
|
||||
data_lock = threading.Lock()
|
||||
|
||||
# 绘图缓冲区
|
||||
raw_data_buf = deque(maxlen=MAX_PLOT_POINTS)
|
||||
filt_data_buf = deque(maxlen=MAX_PLOT_POINTS)
|
||||
|
||||
# ===================== 全新统计变量(区分预热/正式统计) =====================
|
||||
stat = {
|
||||
# 全局总包数(包含预热包)
|
||||
"total_send": 0,
|
||||
"total_recv": 0,
|
||||
|
||||
# 有效统计区间(预热250包之后)
|
||||
"valid_send": 0, # 有效发包数
|
||||
"valid_recv": 0, # 有效收包数
|
||||
"theo_recv": 0, # 理论应收到包数 = valid_send // PACK_RATIO
|
||||
|
||||
# 运行时间
|
||||
"start_time": time.perf_counter(),
|
||||
"last_print_time": time.perf_counter(),
|
||||
|
||||
# 数据校验缓存:保存最新一包原始64通道数据,用于和回包比对
|
||||
"latest_raw_64ch": None
|
||||
}
|
||||
|
||||
# ===================== 【3. 日志配置】 =====================
|
||||
def init_logger():
|
||||
log_format = "%(asctime)s | %(levelname)-8s | %(message)s"
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format=log_format,
|
||||
datefmt="%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
return logging.getLogger("FilterTest")
|
||||
|
||||
logger = init_logger()
|
||||
|
||||
# ===================== 【4. 仿真脑电数据生成 (5,66)】 =====================
|
||||
def generate_eeg_packet(pkt_idx: int) -> np.ndarray:
|
||||
"""生成单包 (5,66) 仿真数据"""
|
||||
n_point, n_chan = PKG_SEND_SHAPE
|
||||
base_t = pkt_idx * n_point / SAMPLE_RATE
|
||||
t_arr = base_t + np.arange(n_point) / SAMPLE_RATE
|
||||
|
||||
data = np.zeros((n_point, n_chan), dtype=np.float64)
|
||||
|
||||
# 64路脑电信号
|
||||
for ch in range(CH_EEG_VALID):
|
||||
sig = 0.0
|
||||
for freq in SIGNAL_FREQ_LIST:
|
||||
sig += SIGNAL_AMP * np.sin(2 * np.pi * freq * t_arr)
|
||||
# sig += NOISE_POWER50_AMP * np.sin(2 * np.pi * 50 * t_arr)
|
||||
# sig += NOISE_GAUSSIAN_AMP * np.random.randn(n_point)
|
||||
data[:, ch] = sig
|
||||
|
||||
# 事件通道、保留通道
|
||||
data[:, CH_EVENT] = EVENT_LABEL_VAL
|
||||
data[:, CH_RESERVED] = RESERVED_VAL
|
||||
return data
|
||||
|
||||
# ===================== 【5. ZMQ 核心IO线程(单连接+Poller,保留原有通信逻辑)】 =====================
|
||||
def zmq_io_thread():
|
||||
context = zmq.Context()
|
||||
pkt_index = 0
|
||||
send_interval = SEND_INTERVAL
|
||||
|
||||
logger.info(f"滤波预热配置:{PREHEAT_SECONDS}秒 / {PREHEAT_SEND_PACKS} 个发包后开始统计")
|
||||
logger.info(f"收发比例:每 {PACK_RATIO} 个发包 → 1 个滤波回包")
|
||||
|
||||
while g_running.is_set():
|
||||
try:
|
||||
sock = context.socket(zmq.DEALER)
|
||||
sock.setsockopt(zmq.RCVTIMEO, ZMQ_SOCKET_TIMEOUT)
|
||||
sock.setsockopt(zmq.SNDTIMEO, ZMQ_SOCKET_TIMEOUT)
|
||||
sock.connect(f"tcp://{ZMQ_SERVER_IP}:{ZMQ_SERVER_PORT}")
|
||||
logger.info(f"ZMQ 连接成功 -> {ZMQ_SERVER_IP}:{ZMQ_SERVER_PORT}")
|
||||
|
||||
poller = zmq.Poller()
|
||||
poller.register(sock, zmq.POLLIN)
|
||||
next_send_ts = time.perf_counter()
|
||||
|
||||
while g_running.is_set():
|
||||
# 全局运行时长限制
|
||||
if MAX_RUN_SECONDS is not None:
|
||||
run_sec = time.perf_counter() - stat["start_time"]
|
||||
if run_sec > MAX_RUN_SECONDS:
|
||||
logger.info(f"已到达设定运行时长 {MAX_RUN_SECONDS}s,停止任务")
|
||||
return
|
||||
|
||||
# ========== 1. 轮询接收服务端回包 ==========
|
||||
socks_ready = dict(poller.poll(POLL_TIMEOUT))
|
||||
if sock in socks_ready:
|
||||
frames = sock.recv_multipart()
|
||||
if not frames:
|
||||
continue
|
||||
recv_bytes = frames[-1]
|
||||
if not recv_bytes:
|
||||
continue
|
||||
|
||||
# 解析回包 (50,64)
|
||||
filt_data = np.frombuffer(recv_bytes, dtype=np.float64)
|
||||
expect_size = PKG_RECV_SHAPE[0] * PKG_RECV_SHAPE[1]
|
||||
if filt_data.size != expect_size:
|
||||
logger.warning(f"回包长度异常:实际{filt_data.size},预期{expect_size}")
|
||||
continue
|
||||
filt_data = filt_data.reshape(PKG_RECV_SHAPE)
|
||||
|
||||
# 全局收包计数
|
||||
stat["total_recv"] += 1
|
||||
|
||||
# 仅预热完成后,计入有效统计收包
|
||||
if stat["total_send"] > PREHEAT_SEND_PACKS:
|
||||
stat["valid_recv"] += 1
|
||||
|
||||
# 写入绘图缓冲区
|
||||
with data_lock:
|
||||
filt_data_buf.extend(filt_data[:, TARGET_CHANNEL])
|
||||
|
||||
# ---------- 新增:64通道数据比对(发包前64通道 <-> 回包64通道) ----------
|
||||
raw_64ch = stat["latest_raw_64ch"]
|
||||
if raw_64ch is not None:
|
||||
raw_mean = np.mean(raw_64ch)
|
||||
filt_mean = np.mean(filt_data)
|
||||
raw_amp = np.max(np.abs(raw_64ch))
|
||||
filt_amp = np.max(np.abs(filt_data))
|
||||
logger.debug(
|
||||
f"【通道数据比对】原始64通道均值:{raw_mean:.4f} 幅值:{raw_amp:.4f} | "
|
||||
f"滤波后均值:{filt_mean:.4f} 幅值:{filt_amp:.4f}"
|
||||
)
|
||||
|
||||
# ========== 2. 精准定时发送数据包 ==========
|
||||
current_ts = time.perf_counter()
|
||||
if current_ts >= next_send_ts:
|
||||
# 生成(5,66)仿真包
|
||||
pkt_data = generate_eeg_packet(pkt_index)
|
||||
pkt_index += 1
|
||||
send_buf = pkt_data.tobytes()
|
||||
|
||||
# 标准三帧Multipart发送
|
||||
sock.send_multipart([CLIENT_ID, EMPTY_FRAME, send_buf])
|
||||
|
||||
# ---------- 发包计数逻辑(核心优化:预热区分) ----------
|
||||
stat["total_send"] += 1
|
||||
# 预热完成后,计入有效发包
|
||||
if stat["total_send"] > PREHEAT_SEND_PACKS:
|
||||
stat["valid_send"] += 1
|
||||
# 计算理论应收包数
|
||||
stat["theo_recv"] = stat["valid_send"] // PACK_RATIO
|
||||
|
||||
# 缓存当前包前64通道,用于后续数据比对
|
||||
stat["latest_raw_64ch"] = pkt_data[:, :CH_EEG_VALID]
|
||||
|
||||
# 绘图缓冲区(单通道波形)
|
||||
with data_lock:
|
||||
raw_data_buf.extend(pkt_data[:, TARGET_CHANNEL])
|
||||
|
||||
# 更新下一次发包时间
|
||||
next_send_ts += send_interval
|
||||
|
||||
# ========== 3. 定时打印统计信息(区分预热/正式统计) ==========
|
||||
now = time.perf_counter()
|
||||
if now - stat["last_print_time"] > PRINT_STAT_INTERVAL:
|
||||
run_sec = now - stat["start_time"]
|
||||
total_send = stat["total_send"]
|
||||
total_recv = stat["total_recv"]
|
||||
|
||||
# 分支1:仍在预热阶段
|
||||
if total_send <= PREHEAT_SEND_PACKS:
|
||||
remain = PREHEAT_SEND_PACKS - total_send
|
||||
logger.info(
|
||||
f"[预热中] 运行:{run_sec:.1f}s | 已发包:{total_send}/{PREHEAT_SEND_PACKS} | "
|
||||
f"剩余预热包:{remain} | 暂不统计丢包"
|
||||
)
|
||||
# 分支2:预热完成,进入正式统计
|
||||
else:
|
||||
v_send = stat["valid_send"]
|
||||
v_recv = stat["valid_recv"]
|
||||
t_recv = stat["theo_recv"]
|
||||
loss_cnt = t_recv - v_recv
|
||||
loss_rate = (loss_cnt / t_recv * 100) if t_recv > 0 else 0.0
|
||||
|
||||
logger.info(
|
||||
f"[正式统计] 运行:{run_sec:.1f}s | "
|
||||
f"全局总包: 发{total_send}/收{total_recv} | "
|
||||
f"有效区间: 发{v_send}/应收{t_recv}/实收{v_recv} | "
|
||||
f"丢包数:{loss_cnt} | 丢包率:{loss_rate:.2f}%"
|
||||
)
|
||||
stat["last_print_time"] = now
|
||||
|
||||
except zmq.ZMQError as e:
|
||||
if e.errno == zmq.EAGAIN:
|
||||
continue
|
||||
logger.warning(f"ZMQ 连接异常: {e}")
|
||||
sock.close()
|
||||
poller.unregister(sock)
|
||||
if not ENABLE_RECONNECT:
|
||||
break
|
||||
logger.info("500ms 后尝试重连...")
|
||||
time.sleep(0.5)
|
||||
except Exception as e:
|
||||
logger.error(f"IO线程未知异常:\n{traceback.format_exc()}")
|
||||
break
|
||||
|
||||
context.term()
|
||||
logger.info("ZMQ IO 线程已退出")
|
||||
|
||||
# ===================== 【6. 可视化绘图(无改动)】 =====================
|
||||
def init_plot():
|
||||
fig = plt.figure(figsize=(14, 9))
|
||||
fig.suptitle(f"脑电滤波测试 | 观测通道: {TARGET_CHANNEL}", fontsize=14)
|
||||
|
||||
ax1 = plt.subplot(2, 2, 1)
|
||||
ax1.set_title("原始输入波形 (含噪声+工频)")
|
||||
ax1.set_ylabel("幅值")
|
||||
ax1.grid(True, alpha=0.3)
|
||||
line_raw, = ax1.plot([], [], color="#1f77b4", linewidth=1)
|
||||
|
||||
ax2 = plt.subplot(2, 2, 2)
|
||||
ax2.set_title("滤波后输出波形")
|
||||
ax2.set_ylabel("幅值")
|
||||
ax2.grid(True, alpha=0.3)
|
||||
line_filt, = ax2.plot([], [], color="#d62728", linewidth=1)
|
||||
|
||||
ax3 = plt.subplot(2, 2, 3)
|
||||
ax3.set_title("原始信号频谱")
|
||||
ax3.set_xlabel("频率 (Hz)")
|
||||
ax3.set_xlim(*PLOT_X_LIMIT_FREQ)
|
||||
ax3.grid(True, alpha=0.3)
|
||||
line_raw_fft, = ax3.plot([], [], color="#1f77b4")
|
||||
|
||||
ax4 = plt.subplot(2, 2, 4)
|
||||
ax4.set_title("滤波后信号频谱")
|
||||
ax4.set_xlabel("频率 (Hz)")
|
||||
ax4.set_xlim(*PLOT_X_LIMIT_FREQ)
|
||||
ax4.grid(True, alpha=0.3)
|
||||
line_filt_fft, = ax4.plot([], [], color="#d62728")
|
||||
|
||||
plt.tight_layout(rect=[0, 0, 1, 0.96])
|
||||
return fig, [line_raw, line_filt, line_raw_fft, line_filt_fft], [ax1, ax2, ax3, ax4]
|
||||
|
||||
def update_plot(frame, lines, axes):
|
||||
line_raw, line_filt, line_raw_fft, line_filt_fft = lines
|
||||
ax1, ax2, ax3, ax4 = axes
|
||||
|
||||
with data_lock:
|
||||
raw_data = list(raw_data_buf)
|
||||
filt_data = list(filt_data_buf)
|
||||
|
||||
if raw_data:
|
||||
x_raw = np.arange(len(raw_data))
|
||||
line_raw.set_data(x_raw, raw_data)
|
||||
ax1.relim()
|
||||
ax1.autoscale_view()
|
||||
|
||||
if filt_data:
|
||||
x_filt = np.arange(len(filt_data))
|
||||
line_filt.set_data(x_filt, filt_data)
|
||||
ax2.relim()
|
||||
ax2.autoscale_view()
|
||||
|
||||
def calc_fft(sig, n_fft):
|
||||
if len(sig) < n_fft:
|
||||
return [], []
|
||||
win = np.hanning(n_fft)
|
||||
sig_win = sig[-n_fft:] * win
|
||||
fft_vals = np.fft.fft(sig_win)
|
||||
fft_amp = np.abs(fft_vals)[:n_fft//2]
|
||||
freq = np.fft.fftfreq(n_fft, 1/SAMPLE_RATE)[:n_fft//2]
|
||||
return freq, fft_amp
|
||||
|
||||
freq_raw, amp_raw = calc_fft(raw_data, FFT_N_POINTS)
|
||||
freq_filt, amp_filt = calc_fft(filt_data, FFT_N_POINTS)
|
||||
|
||||
line_raw_fft.set_data(freq_raw, amp_raw)
|
||||
line_filt_fft.set_data(freq_filt, amp_filt)
|
||||
ax3.relim()
|
||||
ax3.autoscale_view(scaley=True)
|
||||
ax4.relim()
|
||||
ax4.autoscale_view(scaley=True)
|
||||
|
||||
return lines
|
||||
|
||||
# ===================== 【7. 资源释放 & 最终汇总统计】 =====================
|
||||
def clean_resource():
|
||||
g_running.clear()
|
||||
logger.info("开始停止所有线程...")
|
||||
time.sleep(0.3)
|
||||
plt.close("all")
|
||||
logger.info("资源释放完成")
|
||||
|
||||
def main():
|
||||
logger.info("=" * 70)
|
||||
logger.info("脑电滤波测试客户端【统计逻辑优化版】启动")
|
||||
logger.info(f"服务端地址: {ZMQ_SERVER_IP}:{ZMQ_SERVER_PORT}")
|
||||
logger.info(f"发包: {PKG_SEND_SHAPE}({SEND_INTERVAL*1000:.0f}ms) | 回包: {PKG_RECV_SHAPE}({RECV_INTERVAL*1000:.0f}ms)")
|
||||
logger.info(f"预热规则: {PREHEAT_SECONDS}秒 / {PREHEAT_SEND_PACKS} 包后开启统计")
|
||||
logger.info(f"收发比例: 每 {PACK_RATIO} 个发包对应 1 个回包")
|
||||
logger.info("=" * 70)
|
||||
|
||||
# 启动ZMQ收发线程
|
||||
io_thread = threading.Thread(target=zmq_io_thread, daemon=True, name="ZMQ_IO_Thread")
|
||||
io_thread.start()
|
||||
|
||||
# 启动可视化
|
||||
fig, lines, axes = init_plot()
|
||||
ani = FuncAnimation(
|
||||
fig, update_plot,
|
||||
fargs=(lines, axes),
|
||||
interval=PLOT_REFRESH_INTERVAL,
|
||||
blit=True,
|
||||
cache_frame_data=False
|
||||
)
|
||||
|
||||
try:
|
||||
plt.show()
|
||||
except KeyboardInterrupt:
|
||||
logger.info("收到 Ctrl+C 中断信号,准备退出")
|
||||
finally:
|
||||
# 输出最终完整汇总报表
|
||||
run_total = time.perf_counter() - stat["start_time"]
|
||||
total_send = stat["total_send"]
|
||||
total_recv = stat["total_recv"]
|
||||
v_send = stat["valid_send"]
|
||||
v_recv = stat["valid_recv"]
|
||||
t_recv = stat["theo_recv"]
|
||||
|
||||
loss_cnt = t_recv - v_recv
|
||||
loss_rate = (loss_cnt / t_recv * 100) if t_recv > 0 else 0.0
|
||||
|
||||
logger.info(f"\n{'='*50} 最终运行汇总 {'='*50}")
|
||||
logger.info(f"总运行时长: {run_total:.1f} s")
|
||||
logger.info(f"【全局总包数】发送: {total_send} | 接收: {total_recv}")
|
||||
logger.info(f"【有效统计区间(跳过预热{PREHEAT_SEND_PACKS}包)】")
|
||||
logger.info(f" 有效发包: {v_send} | 理论应收包: {t_recv} | 实际收包: {v_recv}")
|
||||
logger.info(f" 总丢包数: {loss_cnt} | 整体丢包率: {loss_rate:.2f} %")
|
||||
logger.info(f"{'='*106}")
|
||||
|
||||
clean_resource()
|
||||
sys.exit(0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
125
logs/log.py
125
logs/log.py
@@ -1,122 +1,87 @@
|
||||
# log.py
|
||||
import os
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime
|
||||
import logging
|
||||
from logging.handlers import RotatingFileHandler
|
||||
import inspect
|
||||
from PubLibrary.InifileHelper import IniRead
|
||||
|
||||
|
||||
# 全局配置
|
||||
console_output = IniRead('system', 'console_output', '1')
|
||||
log_level = IniRead('system', 'algo_log_level', 'INFO')
|
||||
|
||||
# 新增:日志去重缓存,key为日志内容,value为是否已打印
|
||||
log_once_cache = set()
|
||||
logger_cache = {}
|
||||
LOG_RETENTION_DAYS = 3
|
||||
|
||||
LOG_PATH_STR = IniRead('system', 'algo_log_path', "d:/Program Files/64chn_Decoder/logs")
|
||||
LOG_DIR = Path(LOG_PATH_STR)
|
||||
# 自动补全路径分隔符,创建目录(不存在则新建,避免写日志报错)
|
||||
LOG_DIR.mkdir(parents=True, exist_ok=True)
|
||||
# 如需字符串格式路径
|
||||
LOG_DIR_STR = str(LOG_DIR) + "\\"
|
||||
LOG_FILE_PREFIX = 'algo_log_'
|
||||
|
||||
# 日志格式:时间 - 日志器名 - 级别 - 文件名:行号 - 函数名 - 日志内容
|
||||
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
DATE_FORMAT = '%Y-%m-%d %H:%M:%S'
|
||||
|
||||
|
||||
def clean_old_logs():
|
||||
"""清理超过指定天数的旧日志文件"""
|
||||
try:
|
||||
if not os.path.exists(LOG_DIR):
|
||||
return
|
||||
expire_date = datetime.now() - timedelta(days=LOG_RETENTION_DAYS)
|
||||
for filename in os.listdir(LOG_DIR):
|
||||
if not filename.startswith(LOG_FILE_PREFIX) or not filename.endswith('.log'):
|
||||
continue
|
||||
date_str = filename[len(LOG_FILE_PREFIX):-4]
|
||||
try:
|
||||
file_date = datetime.strptime(date_str, '%Y-%m-%d')
|
||||
if file_date < expire_date:
|
||||
file_path = os.path.join(LOG_DIR, filename)
|
||||
os.remove(file_path)
|
||||
print(f"清理过期日志: {file_path}")
|
||||
except ValueError:
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"清理旧日志异常: {str(e)}")
|
||||
def init_module_logger():
|
||||
"""
|
||||
初始化指定模块的日志器
|
||||
:return: 对应模块的logger实例
|
||||
"""
|
||||
# 缓存命中则直接返回
|
||||
log_dir = './logs/' # 确保日志目录存在
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
log_file = os.path.join(log_dir, f'algo_log_{datetime.now().strftime("%Y-%m-%d")}.log')
|
||||
|
||||
def init_module_logger(logger_name):
|
||||
"""初始化日志器 + 清理旧日志"""
|
||||
os.makedirs(LOG_DIR, exist_ok=True)
|
||||
clean_old_logs()
|
||||
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
log_file = os.path.join(LOG_DIR, f"{LOG_FILE_PREFIX}{current_date}.log")
|
||||
|
||||
if logger_name in logger_cache:
|
||||
return logger_cache[logger_name]
|
||||
|
||||
logger = logging.getLogger(logger_name)
|
||||
# 初始化logger
|
||||
logger = logging.getLogger('decoderLogger')
|
||||
logger.setLevel(log_level)
|
||||
|
||||
if logger.handlers:
|
||||
logger_cache[logger_name] = logger
|
||||
return logger
|
||||
|
||||
# 文件输出处理器
|
||||
# 设置日志轮转,最大10个文件,每个10MB
|
||||
file_handler = RotatingFileHandler(
|
||||
log_file,
|
||||
maxBytes=10*1024*1024,
|
||||
backupCount=10,
|
||||
encoding='utf-8'
|
||||
)
|
||||
formatter = logging.Formatter(LOG_FORMAT, datefmt=DATE_FORMAT)
|
||||
|
||||
# 日志格式
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
file_handler.setFormatter(formatter)
|
||||
logger.setLevel(log_level)
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
# 控制台输出
|
||||
if console_output:
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setFormatter(formatter)
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
logger_cache[logger_name] = logger
|
||||
return logger
|
||||
|
||||
|
||||
def algo_log(content, level="INFO", record_once=False):
|
||||
"""
|
||||
日志入口函数
|
||||
自动记录:调用文件名、代码行号、所在函数
|
||||
通用日志函数,支持按模块输出到不同日志文件
|
||||
:param content: 日志内容
|
||||
:param level: 日志级别(DEBUG/INFO/WARNING/ERROR/FATAL)
|
||||
:param record_once: 是否只打印一次该日志内容,默认False
|
||||
"""
|
||||
# 回溯栈帧,获取真正调用 algo_log 的代码位置
|
||||
# f_back(1) -> algo_log 自身,f_back(2) -> 业务调用处
|
||||
frame = inspect.currentframe().f_back.f_back
|
||||
if not frame:
|
||||
file_name = "unknown"
|
||||
else:
|
||||
file_name = os.path.basename(frame.f_code.co_filename)
|
||||
# 初始化模块日志器
|
||||
logger = init_module_logger()
|
||||
|
||||
logger = init_module_logger(file_name)
|
||||
|
||||
# 单次日志去重
|
||||
# 新增:处理只打印一次的逻辑
|
||||
if record_once:
|
||||
# 生成唯一标识(可根据需要调整,比如拼接level增强唯一性)
|
||||
log_key = f"{level.upper()}_{content}"
|
||||
if log_key in log_once_cache:
|
||||
return
|
||||
log_once_cache.add(log_key)
|
||||
return # 已打印过,直接返回
|
||||
log_once_cache.add(log_key) # 未打印过,加入缓存
|
||||
|
||||
# 日志级别分发
|
||||
# 根据级别输出日志
|
||||
level_upper = level.upper()
|
||||
log_map = {
|
||||
"DEBUG": logger.debug,
|
||||
"WARNING": logger.warning,
|
||||
"ERROR": logger.error,
|
||||
"FATAL": logger.fatal,
|
||||
"INFO": logger.info
|
||||
}
|
||||
log_func = log_map.get(level_upper, logger.info)
|
||||
log_func(content)
|
||||
if level_upper == "DEBUG":
|
||||
logger.debug(content)
|
||||
elif level_upper == "WARNING":
|
||||
logger.warning(content)
|
||||
elif level_upper == "ERROR":
|
||||
logger.error(content)
|
||||
elif level_upper == "FATAL":
|
||||
logger.fatal(content)
|
||||
else: # 默认INFO级别
|
||||
logger.info(content)
|
||||
@@ -1,55 +0,0 @@
|
||||
#!/bin/bash
|
||||
# Git Bash 中文 UTF-8 兼容配置(通用版,无报错)
|
||||
export LC_ALL=en_US.UTF-8
|
||||
export LANG=en_US.UTF-8
|
||||
|
||||
echo "========================"
|
||||
echo "Nuitka 打包脚本 - 优化稳定版"
|
||||
echo "适配:PyTorch2.0.0 + CUDA11.7 + 脑电解码项目"
|
||||
echo "========================"
|
||||
|
||||
# ===================== 自定义配置区 =====================
|
||||
PY_FILE="runDecoder.py" # 主程序文件
|
||||
OUT_DIR="dist_nuitka" # 输出文件夹
|
||||
MODEL_DIR="online_Models" # 模型文件夹
|
||||
# ========================================================
|
||||
|
||||
# 检查主文件是否存在
|
||||
if [ ! -f "${PY_FILE}" ]; then
|
||||
echo "错误:未找到主文件 ${PY_FILE},请检查路径!"
|
||||
read -n 1 -s -r -p "按任意键退出"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "开始打包:${PY_FILE}"
|
||||
echo "输出目录:${OUT_DIR}"
|
||||
|
||||
# Nuitka 核心打包命令(无错误、无冗余、全依赖)
|
||||
python -m nuitka \
|
||||
--standalone \
|
||||
--msvc=latest \
|
||||
--windows-console-mode=force \
|
||||
--module-parameter=torch-disable-jit=yes \
|
||||
--enable-plugin=no-qt \
|
||||
--include-package=numpy \
|
||||
--include-module=numpy.core._multiarray_umath \
|
||||
--include-package=scipy \
|
||||
--no-deployment-flag=self-execution \
|
||||
--include-data-dir="${MODEL_DIR}=${MODEL_DIR}" \
|
||||
--output-dir="${OUT_DIR}" \
|
||||
--remove-output \
|
||||
"${PY_FILE}"
|
||||
|
||||
# 打包结果判断
|
||||
if [ $? -eq 0 ]; then
|
||||
echo -e "\n========================"
|
||||
echo "✅ 打包成功!"
|
||||
echo "📦 产物路径:${OUT_DIR}/${PY_FILE%.py}.exe"
|
||||
echo "========================"
|
||||
else
|
||||
echo -e "\n❌ 打包失败!"
|
||||
fi
|
||||
|
||||
# Git Bash 兼容的暂停
|
||||
read -n 1 -s -r -p "按任意键退出..."
|
||||
echo
|
||||
BIN
online_Models/Model_2025-11-15-11-11-50.pth
Normal file
BIN
online_Models/Model_2025-11-15-11-11-50.pth
Normal file
Binary file not shown.
BIN
online_Models/Model_2025-11-17-16-55-25.pth
Normal file
BIN
online_Models/Model_2025-11-17-16-55-25.pth
Normal file
Binary file not shown.
BIN
online_Models/Model_2025-11-18-10-15-35.pth
Normal file
BIN
online_Models/Model_2025-11-18-10-15-35.pth
Normal file
Binary file not shown.
BIN
online_Models/Model_2026-01-08-14-55-10.pth
Normal file
BIN
online_Models/Model_2026-01-08-14-55-10.pth
Normal file
Binary file not shown.
BIN
online_Models/Model_2026-05-26-14-21-38.pth
Normal file
BIN
online_Models/Model_2026-05-26-14-21-38.pth
Normal file
Binary file not shown.
BIN
online_Models/Model_2026-05-26-15-26-09.pth
Normal file
BIN
online_Models/Model_2026-05-26-15-26-09.pth
Normal file
Binary file not shown.
BIN
online_Models/Model_2026-05-26-15-44-21.pth
Normal file
BIN
online_Models/Model_2026-05-26-15-44-21.pth
Normal file
Binary file not shown.
BIN
online_Models/Model_2026-05-26-16-06-24.pth
Normal file
BIN
online_Models/Model_2026-05-26-16-06-24.pth
Normal file
Binary file not shown.
BIN
online_Models/Model_2026-05-26-16-30-12.pth
Normal file
BIN
online_Models/Model_2026-05-26-16-30-12.pth
Normal file
Binary file not shown.
BIN
online_Models/Model_2026-05-26-16-44-52.pth
Normal file
BIN
online_Models/Model_2026-05-26-16-44-52.pth
Normal file
Binary file not shown.
BIN
online_Models/Model_2026-05-30-13-08-50.pth
Normal file
BIN
online_Models/Model_2026-05-30-13-08-50.pth
Normal file
Binary file not shown.
252
online_Models/log_result.txt
Normal file
252
online_Models/log_result.txt
Normal file
@@ -0,0 +1,252 @@
|
||||
0 0.5
|
||||
1 0.5
|
||||
2 0.375
|
||||
3 0.5
|
||||
4 0.4375
|
||||
5 0.375
|
||||
6 0.5
|
||||
7 0.5
|
||||
8 0.375
|
||||
9 0.375
|
||||
10 0.375
|
||||
11 0.375
|
||||
12 0.5
|
||||
13 0.5625
|
||||
14 0.5625
|
||||
15 0.5
|
||||
16 0.5
|
||||
17 0.5
|
||||
18 0.5
|
||||
19 0.5625
|
||||
20 0.4375
|
||||
21 0.5
|
||||
22 0.5
|
||||
23 0.375
|
||||
24 0.375
|
||||
25 0.375
|
||||
26 0.375
|
||||
27 0.375
|
||||
28 0.3125
|
||||
29 0.375
|
||||
30 0.5625
|
||||
31 0.5
|
||||
32 0.5
|
||||
33 0.5625
|
||||
34 0.5625
|
||||
35 0.3125
|
||||
36 0.3125
|
||||
37 0.3125
|
||||
38 0.375
|
||||
39 0.5625
|
||||
40 0.3125
|
||||
41 0.5625
|
||||
42 0.3125
|
||||
43 0.375
|
||||
44 0.5625
|
||||
45 0.5
|
||||
46 0.375
|
||||
47 0.375
|
||||
48 0.3125
|
||||
49 0.375
|
||||
50 0.375
|
||||
51 0.5
|
||||
52 0.5625
|
||||
53 0.375
|
||||
54 0.5625
|
||||
55 0.5625
|
||||
56 0.375
|
||||
57 0.375
|
||||
58 0.375
|
||||
59 0.5
|
||||
60 0.3125
|
||||
61 0.375
|
||||
62 0.375
|
||||
63 0.375
|
||||
64 0.375
|
||||
65 0.375
|
||||
66 0.3125
|
||||
67 0.375
|
||||
68 0.5625
|
||||
69 0.5625
|
||||
70 0.5625
|
||||
71 0.5
|
||||
72 0.5625
|
||||
73 0.375
|
||||
74 0.375
|
||||
75 0.375
|
||||
76 0.375
|
||||
77 0.375
|
||||
78 0.5
|
||||
79 0.375
|
||||
80 0.375
|
||||
81 0.5
|
||||
82 0.375
|
||||
83 0.375
|
||||
84 0.375
|
||||
85 0.375
|
||||
86 0.3125
|
||||
87 0.375
|
||||
88 0.375
|
||||
89 0.5
|
||||
90 0.375
|
||||
91 0.4375
|
||||
92 0.3125
|
||||
93 0.3125
|
||||
94 0.375
|
||||
95 0.375
|
||||
96 0.375
|
||||
97 0.375
|
||||
98 0.3125
|
||||
99 0.4375
|
||||
100 0.375
|
||||
101 0.375
|
||||
102 0.375
|
||||
103 0.3125
|
||||
104 0.5625
|
||||
105 0.5
|
||||
106 0.5625
|
||||
107 0.5625
|
||||
108 0.5
|
||||
109 0.3125
|
||||
110 0.5625
|
||||
111 0.5625
|
||||
112 0.5
|
||||
113 0.3125
|
||||
114 0.5
|
||||
115 0.3125
|
||||
116 0.375
|
||||
117 0.3125
|
||||
118 0.3125
|
||||
119 0.3125
|
||||
120 0.3125
|
||||
121 0.375
|
||||
122 0.375
|
||||
123 0.375
|
||||
124 0.375
|
||||
125 0.3125
|
||||
126 0.375
|
||||
127 0.375
|
||||
128 0.375
|
||||
129 0.375
|
||||
130 0.5625
|
||||
131 0.375
|
||||
132 0.5
|
||||
133 0.3125
|
||||
134 0.3125
|
||||
135 0.3125
|
||||
136 0.375
|
||||
137 0.5
|
||||
138 0.3125
|
||||
139 0.375
|
||||
140 0.3125
|
||||
141 0.3125
|
||||
142 0.3125
|
||||
143 0.5625
|
||||
144 0.3125
|
||||
145 0.375
|
||||
146 0.5
|
||||
147 0.5
|
||||
148 0.375
|
||||
149 0.4375
|
||||
150 0.5
|
||||
151 0.3125
|
||||
152 0.375
|
||||
153 0.375
|
||||
154 0.375
|
||||
155 0.3125
|
||||
156 0.375
|
||||
157 0.4375
|
||||
158 0.4375
|
||||
159 0.375
|
||||
160 0.375
|
||||
161 0.3125
|
||||
162 0.375
|
||||
163 0.375
|
||||
164 0.375
|
||||
165 0.3125
|
||||
166 0.3125
|
||||
167 0.3125
|
||||
168 0.375
|
||||
169 0.3125
|
||||
170 0.3125
|
||||
171 0.3125
|
||||
172 0.375
|
||||
173 0.3125
|
||||
174 0.3125
|
||||
175 0.5
|
||||
176 0.3125
|
||||
177 0.375
|
||||
178 0.375
|
||||
179 0.3125
|
||||
180 0.3125
|
||||
181 0.3125
|
||||
182 0.3125
|
||||
183 0.5625
|
||||
184 0.5625
|
||||
185 0.3125
|
||||
186 0.5
|
||||
187 0.5
|
||||
188 0.5625
|
||||
189 0.5
|
||||
190 0.5625
|
||||
191 0.5625
|
||||
192 0.5625
|
||||
193 0.5
|
||||
194 0.5
|
||||
195 0.5625
|
||||
196 0.5625
|
||||
197 0.5625
|
||||
198 0.5625
|
||||
199 0.5
|
||||
200 0.5625
|
||||
201 0.5625
|
||||
202 0.375
|
||||
203 0.375
|
||||
204 0.375
|
||||
205 0.375
|
||||
206 0.375
|
||||
207 0.5
|
||||
208 0.5
|
||||
209 0.5625
|
||||
210 0.5625
|
||||
211 0.5625
|
||||
212 0.3125
|
||||
213 0.5
|
||||
214 0.5
|
||||
215 0.5625
|
||||
216 0.5
|
||||
217 0.5
|
||||
218 0.5
|
||||
219 0.5625
|
||||
220 0.5
|
||||
221 0.4375
|
||||
222 0.5
|
||||
223 0.5
|
||||
224 0.4375
|
||||
225 0.5
|
||||
226 0.4375
|
||||
227 0.5
|
||||
228 0.5
|
||||
229 0.375
|
||||
230 0.375
|
||||
231 0.3125
|
||||
232 0.375
|
||||
233 0.375
|
||||
234 0.375
|
||||
235 0.5625
|
||||
236 0.5625
|
||||
237 0.5625
|
||||
238 0.5625
|
||||
239 0.5625
|
||||
240 0.5
|
||||
241 0.5
|
||||
242 0.5
|
||||
243 0.5625
|
||||
244 0.5625
|
||||
245 0.375
|
||||
246 0.375
|
||||
247 0.375
|
||||
248 0.3125
|
||||
249 0.375
|
||||
The average accuracy is: 0.42675
|
||||
The best accuracy is: 0.5625
|
||||
@@ -1,52 +0,0 @@
|
||||
Bottleneck==1.4.2
|
||||
brotlicffi==1.2.0.0
|
||||
certifi==2026.5.20
|
||||
cffi==2.0.0
|
||||
charset-normalizer==3.4.4
|
||||
contourpy==1.3.2
|
||||
cycler==0.12.1
|
||||
einops==0.8.2
|
||||
filelock==3.20.3
|
||||
fonttools==4.63.0
|
||||
gmpy2==2.2.2
|
||||
idna==3.11
|
||||
Jinja2==3.1.6
|
||||
joblib==1.5.3
|
||||
kiwisolver==1.5.0
|
||||
MarkupSafe==3.0.2
|
||||
matplotlib==3.10.9
|
||||
mkl_fft==1.3.11
|
||||
mkl_random==1.2.8
|
||||
mkl-service==2.5.2
|
||||
mpmath==1.3.0
|
||||
networkx==3.4.2
|
||||
Nuitka==4.1.1
|
||||
numexpr==2.14.1
|
||||
numpy==1.24.3
|
||||
packaging==26.0
|
||||
pandas==2.3.3
|
||||
pillow==12.2.0
|
||||
pip==26.0.1
|
||||
pycparser==3.0
|
||||
pyparsing==3.3.2
|
||||
pyserial==3.5
|
||||
PySocks==1.7.1
|
||||
python-dateutil==2.9.0.post0
|
||||
pytz==2026.1.post1
|
||||
pyzmq==27.1.0
|
||||
requests==2.33.1
|
||||
scikit-learn==1.7.1
|
||||
scipy==1.15.3
|
||||
setuptools==82.0.1
|
||||
six==1.17.0
|
||||
sympy==1.14.0
|
||||
threadpoolctl==3.5.0
|
||||
torch==2.0.0
|
||||
torchaudio==2.0.0
|
||||
torchsummary==1.5.1
|
||||
torchvision==0.15.0
|
||||
typing_extensions==4.15.0
|
||||
tzdata==2026.2
|
||||
urllib3==2.7.0
|
||||
wheel==0.46.3
|
||||
win_inet_pton==1.1.0
|
||||
@@ -1,38 +1,37 @@
|
||||
# import matplotlib
|
||||
# matplotlib.use('Agg')
|
||||
# import argparse
|
||||
# import sys
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
import argparse
|
||||
import sys
|
||||
import time
|
||||
from Decoder import Decoder_main
|
||||
from PubLibrary.RunOnce import is_program_running
|
||||
from PubLibrary.InifileHelper import IniRead
|
||||
from logs.log import algo_log
|
||||
|
||||
def get_device_info(device_type):
|
||||
|
||||
|
||||
section = f'device_type_{device_type}'
|
||||
device_info = {
|
||||
'sample_rate': int(IniRead(section, 'sample_rate')) if IniRead(section, 'sample_rate') is not None else 250,
|
||||
'frame_points': int(IniRead(section, 'frame_points')) if IniRead(section, 'frame_points') is not None else 5,
|
||||
'channel_nums': int(IniRead(section, 'channel_nums')) if IniRead(section, 'channel_nums') is not None else 66,
|
||||
'channel_names': IniRead(section, 'channel_names') if IniRead(section, 'channel_names') is not None else None,
|
||||
'channel_index': IniRead(section, 'channel_index') if IniRead(section, 'channel_index') is not None else None,
|
||||
'device_sample_rate': int(IniRead(section, 'sample_rate')) if IniRead(section, 'sample_rate') is not None else 250,
|
||||
|
||||
''
|
||||
}
|
||||
|
||||
return device_info
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if not is_program_running():
|
||||
# 解析命令行参数
|
||||
# parser = argparse.ArgumentParser(description="EEG Decoder Application")
|
||||
# parser.add_argument('-dt', '-t','--device-type', type=int, default=None, help="Device Type")
|
||||
parser = argparse.ArgumentParser(description="EEG Decoder Application")
|
||||
parser.add_argument('-dt', '-t','--device-type', type=int, default=None, help="Device Type")
|
||||
# parser.add_argument('-dh', '--device-host', type=str, default=None, help="Device Host IP")
|
||||
# parser.add_argument('-dp', '--device-port', type=int, default=None, help="Device Port")
|
||||
# parser.add_argument('-uh', '--upper-host', type=str, default=None, help="Upper Computer Host IP")
|
||||
# parser.add_argument('-up', '--upper-port', type=int, default=None, help="Upper Computer Port")
|
||||
# args = parser.parse_args()
|
||||
|
||||
args = parser.parse_args()
|
||||
device_info= get_device_info(args.device_type)
|
||||
|
||||
|
||||
decoder = Decoder_main(device_info=device_info)
|
||||
# decoder.connect(
|
||||
# device_type=args.device_type,
|
||||
# device_host=args.device_host,
|
||||
@@ -41,10 +40,6 @@ if __name__ == "__main__":
|
||||
# upper_port=args.upper_port
|
||||
# )
|
||||
|
||||
device_info= get_device_info(1)
|
||||
algo_log(f"device_info: {device_info}", level="DEBUG")
|
||||
decoder = Decoder_main(device_info=device_info)
|
||||
|
||||
try:
|
||||
decoder.start()
|
||||
while not decoder.zmqServer.IsExitApp:
|
||||
|
||||
@@ -1,306 +0,0 @@
|
||||
"""
|
||||
MI_headless.py
|
||||
无界面版 MI 运动想象范式通讯流程模拟脚本。
|
||||
复现 MI_main.py 的完整指令序列(train 0/1, rest, predict, saveData),
|
||||
但不依赖 psychopy 也不打开任何窗口/音频,用 time.sleep 替代帧循环等待。
|
||||
|
||||
启动顺序:
|
||||
1. runDecoder.py
|
||||
2. datamock.py
|
||||
3. MI_headless.py
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import threading
|
||||
import zmq
|
||||
import numpy as np
|
||||
import ast
|
||||
from datetime import datetime
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from PubLibrary.InifileHelper import IniRead
|
||||
|
||||
personname = 'demo'
|
||||
session = '01'
|
||||
|
||||
DATAMOCK_LABEL_ADDR = 'tcp://127.0.0.1:8101' # datamock 标签命令地址
|
||||
|
||||
|
||||
# ========== ZMQ 结果接收服务 ==========
|
||||
class ZmqResultServer(threading.Thread):
|
||||
def __init__(self, port=8088):
|
||||
threading.Thread.__init__(self)
|
||||
self.port = port
|
||||
self.running = True
|
||||
self.energy = 0
|
||||
self.paradigm = 0 # 0=个体校准, 1=康复训练, 2=等待模型训练
|
||||
self.ChoosenNum = -1
|
||||
self.context = zmq.Context()
|
||||
self.socket = self.context.socket(zmq.ROUTER)
|
||||
self.socket.bind(f"tcp://0.0.0.0:{self.port}")
|
||||
self.daemon = True
|
||||
self.trial_idx = 0
|
||||
|
||||
def run(self):
|
||||
print(f"[Server] UpperHost_Server listening on {self.port}")
|
||||
while self.running:
|
||||
try:
|
||||
frames = self.socket.recv_multipart(zmq.NOBLOCK)
|
||||
if len(frames) < 3:
|
||||
continue
|
||||
message = json.loads(frames[2].decode('utf-8'))
|
||||
method = message.get('method')
|
||||
params = message.get('params')
|
||||
if method == 'energy':
|
||||
self.energy = params
|
||||
elif method == 'paradigm':
|
||||
self.paradigm = params
|
||||
print(f"[Server] paradigm -> {params}")
|
||||
elif method == 'result':
|
||||
self.ChoosenNum = params
|
||||
self.trial_idx += 1
|
||||
print(f"[Server] result={self.ChoosenNum} (trial {self.trial_idx})")
|
||||
except zmq.Again:
|
||||
time.sleep(0.005)
|
||||
except Exception as e:
|
||||
print(f"[Server] error: {e}")
|
||||
|
||||
def stop(self):
|
||||
self.running = False
|
||||
self.socket.close()
|
||||
self.context.term()
|
||||
|
||||
|
||||
# ========== ZMQ 命令发送客户端 ==========
|
||||
class ZmqCmdClient:
|
||||
def __init__(self, host, port):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.context = zmq.Context()
|
||||
self.socket = self.context.socket(zmq.DEALER)
|
||||
# PUSH socket 用于向 datamock.py 发送标签命令
|
||||
self._label_sock = self.context.socket(zmq.PUSH)
|
||||
self._label_sock.connect(DATAMOCK_LABEL_ADDR)
|
||||
print(f"[Client] label PUSH connected to {DATAMOCK_LABEL_ADDR}")
|
||||
|
||||
def connect(self):
|
||||
self.socket.connect(f"tcp://{self.host}:{self.port}")
|
||||
print(f"[Client] connected to {self.host}:{self.port}")
|
||||
|
||||
def start_recv_thread(self, result_server):
|
||||
"""启动后台线程,持续接收 decoder 通过 8099 ROUTER 回发的消息,并更新 result_server 的状态"""
|
||||
self._result_server = result_server
|
||||
self._stop_recv = threading.Event()
|
||||
|
||||
def _recv_loop():
|
||||
while not self._stop_recv.is_set():
|
||||
try:
|
||||
frames = self.socket.recv_multipart(zmq.NOBLOCK)
|
||||
# DEALER 收到的格式: [b'', json_bytes]
|
||||
data_bytes = frames[-1]
|
||||
message = json.loads(data_bytes.decode('utf-8'))
|
||||
method = message.get('method')
|
||||
params = message.get('params')
|
||||
ts = datetime.now().strftime('%H:%M:%S.%f')[:-3]
|
||||
print(f"[{ts}] [CmdClient] recv: {method}={params}")
|
||||
if method == 'paradigm':
|
||||
self._result_server.paradigm = params
|
||||
print(f"[{ts}] [CmdClient] paradigm updated -> {params}")
|
||||
elif method == 'result':
|
||||
self._result_server.ChoosenNum = params
|
||||
self._result_server.trial_idx += 1
|
||||
print(f"[{ts}] [CmdClient] result={params} (trial {self._result_server.trial_idx})")
|
||||
elif method == 'energy':
|
||||
self._result_server.energy = params
|
||||
except zmq.Again:
|
||||
time.sleep(0.005)
|
||||
except Exception as e:
|
||||
print(f"[CmdClient recv] error: {e}")
|
||||
time.sleep(0.01)
|
||||
|
||||
self._recv_thread = threading.Thread(target=_recv_loop, daemon=True)
|
||||
self._recv_thread.start()
|
||||
print(f"[Client] 后台接收线程已启动(监听 decoder 8099 回发消息)")
|
||||
|
||||
def stop_recv_thread(self):
|
||||
if hasattr(self, '_stop_recv'):
|
||||
self._stop_recv.set()
|
||||
|
||||
def _send_label(self, label_value):
|
||||
"""向 datamock.py 发送标签命令"""
|
||||
try:
|
||||
self._label_sock.send_string(str(label_value), zmq.NOBLOCK)
|
||||
except Exception as e:
|
||||
print(f"[Client] label send error: {e}")
|
||||
|
||||
def send_data(self, method, params):
|
||||
msg = {'method': method, 'params': params}
|
||||
try:
|
||||
self.socket.send_multipart([b'', json.dumps(msg).encode('utf-8')])
|
||||
ts = datetime.now().strftime('%H:%M:%S.%f')[:-3]
|
||||
print(f"[{ts}] send_data: {method}={params}")
|
||||
# 根据 train/predict 命令向 datamock 发送标签
|
||||
if method == 'train':
|
||||
if params == 0:
|
||||
self._send_label(1)
|
||||
print(f"[Label] train 0 -> datamock label=1")
|
||||
elif params == 1:
|
||||
self._send_label(2)
|
||||
print(f"[Label] train 1 -> datamock label=2")
|
||||
elif method == 'predict':
|
||||
self._send_label(99)
|
||||
print(f"[Label] predict -> datamock label=99")
|
||||
except Exception as e:
|
||||
print(f"[Client] send error: {e}")
|
||||
|
||||
|
||||
# ========== 主流程 ==========
|
||||
def run_headless():
|
||||
server = ZmqResultServer(port=8088)
|
||||
server.start()
|
||||
|
||||
_dh = str(IniRead('system', 'Decoder_Host'))
|
||||
_dp = int(IniRead('system', 'Decoder_Port'))
|
||||
client = ZmqCmdClient(_dh, _dp)
|
||||
client.connect()
|
||||
client.start_recv_thread(server) # 启动后台接收线程,监听 decoder 8099 回发的 paradigm/result 消息
|
||||
|
||||
time.sleep(1) # 等待连接建立
|
||||
client.send_data('decoderClass', 'mi')
|
||||
time.sleep(4) # 等待 zmqServer 排空启动积压包(datamock 提前连接会积压 ~3s 数据)
|
||||
|
||||
# MI_IntervalEpoch = [0.5, 4.5],trial时长 = 4.5-0.5 = 4.0s
|
||||
_mi_iv = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch')) # [0.5, 4.5]
|
||||
_trial_sec = float(_mi_iv[1] - _mi_iv[0]) # 4.0s
|
||||
_margin = 1.0
|
||||
train_time = max(5.0, _trial_sec + _margin) # 训练刺激时长(与 MI_main.py 保持一致)
|
||||
|
||||
# MI epoch latency = interval_epoch[1] // 5 = (4.5*250)//5 = 225包 × 20ms = 4.5s
|
||||
# train_latency = 225包(MI中 train_latency == latency)
|
||||
# 在 train_time 后需再等 epoch_wait 秒,decoder 才能完成 epoch 采集
|
||||
epoch_wait = _mi_iv[1] / _mi_iv[1] * (_mi_iv[1] * 250 // 5) * 0.02 # = latency * 20ms
|
||||
# 更直接的计算:latency = interval_epoch[1] // 5 = int(4.5*250)//5 = 225,225*0.02 = 4.5s
|
||||
epoch_wait = (int(_mi_iv[1] * 250) // 5) * 0.02 # 4.5s
|
||||
|
||||
# predict epoch wait(与 train 相同,MI中 latency == train_latency)
|
||||
predict_epoch_wait = epoch_wait # 4.5s
|
||||
|
||||
test_time = 7.0 # 预测窗口时长(与 MI_main.py 保持一致)
|
||||
right_rehabilitation = float(IniRead('system', 'Right_rehabilitation'))
|
||||
fault_rehabilitation = float(IniRead('system', 'Fault_rehabilitation'))
|
||||
rest_time = float(IniRead('system', 'Rest_time'))
|
||||
|
||||
num_blocks = int(IniRead('system', 'Num_blocks'))
|
||||
num_trials = int(IniRead('system', 'Num_trials'))
|
||||
|
||||
trained = 0
|
||||
Num_Total = 0
|
||||
Num_Success = 0
|
||||
user_choice = []
|
||||
|
||||
print("=" * 50)
|
||||
print("[Headless] 开始运行 MI 通讯流程(无界面)")
|
||||
print(f" MI_IntervalEpoch={_mi_iv}, trial_sec={_trial_sec:.2f}s")
|
||||
print(f" train_time={train_time:.2f}s, epoch_wait={epoch_wait:.2f}s")
|
||||
print(f" test_time={test_time:.2f}s, predict_epoch_wait={predict_epoch_wait:.2f}s")
|
||||
print(f" num_blocks={num_blocks}, num_trials={num_trials}")
|
||||
print("=" * 50)
|
||||
|
||||
try:
|
||||
while True:
|
||||
# -------- 个体校准阶段 --------
|
||||
print("\n[Phase] 个体校准阶段 (paradigm=0)")
|
||||
client.send_data('rest', 0)
|
||||
time.sleep(1)
|
||||
|
||||
while server.paradigm == 0:
|
||||
# 左侧 MI 刺激(train 0,label=1)
|
||||
print(f"\n[Train] 左侧 MI 刺激 (train 0) trained={trained}")
|
||||
client.send_data('rest', 0)
|
||||
time.sleep(0.5) # ding 提示后等待
|
||||
|
||||
client.send_data('train', 0)
|
||||
time.sleep(train_time + 0.2) # 等待刺激时间 + epoch 完成时间
|
||||
|
||||
trained += 1
|
||||
client.send_data('rest', 0)
|
||||
time.sleep(1.0) # 类间休息
|
||||
|
||||
# 空闲态样本采集(train 1,label=2)
|
||||
print(f"\n[Train] 空闲态采集 (train 1) trained={trained}")
|
||||
client.send_data('train', 1)
|
||||
time.sleep(train_time + 0.2) # 等待刺激时间 + epoch 完成时间
|
||||
|
||||
trained += 1
|
||||
client.send_data('rest', 0)
|
||||
time.sleep(1.0) # 类间休息
|
||||
|
||||
# 个体校准阶段结束
|
||||
print("\n[Phase] 个体校准结束,等待模型训练 (paradigm=2) ...")
|
||||
trained = 0
|
||||
time.sleep(1)
|
||||
|
||||
# 等待模型训练完成 (paradigm=2 -> paradigm=1)
|
||||
while server.paradigm == 2:
|
||||
print("[Phase] 等待模型训练完成 ...")
|
||||
time.sleep(0.5)
|
||||
|
||||
# -------- 康复训练阶段 --------
|
||||
while server.paradigm == 1:
|
||||
print("\n[Phase] 康复训练阶段 (paradigm=1)")
|
||||
for block_idx in range(num_blocks):
|
||||
print(f"\n [Block {block_idx+1}/{num_blocks}]")
|
||||
time.sleep(10) # 每轮开始前等待
|
||||
|
||||
for trial_idx in range(num_trials):
|
||||
print(f" [Trial {trial_idx+1}/{num_trials}]")
|
||||
|
||||
time.sleep(0.5) # ding 提示
|
||||
server.ChoosenNum = -1
|
||||
|
||||
# 开始预测
|
||||
# MI predict epoch latency = 225包 × 20ms = 4.5s,需额外等待 epoch 完成
|
||||
client.send_data('predict', 1)
|
||||
t_start = time.perf_counter()
|
||||
while time.perf_counter() - t_start < test_time + predict_epoch_wait:
|
||||
if server.ChoosenNum >= 0:
|
||||
Num_Total += 1
|
||||
user_choice.append(server.ChoosenNum)
|
||||
if server.ChoosenNum == 0:
|
||||
Num_Success += 1
|
||||
rest_time = right_rehabilitation
|
||||
elif server.ChoosenNum == 1:
|
||||
rest_time = fault_rehabilitation
|
||||
break
|
||||
time.sleep(0.02)
|
||||
|
||||
trained += 1
|
||||
client.send_data('rest', 0)
|
||||
time.sleep(0.5)
|
||||
time.sleep(rest_time)
|
||||
server.ChoosenNum = -1
|
||||
|
||||
# 训练结束
|
||||
print("\n[Phase] 康复训练结束")
|
||||
break # 退出康复训练循环
|
||||
|
||||
# 统计结果
|
||||
overall_accuracy = Num_Success / Num_Total if Num_Total > 0 else 0
|
||||
print(f"\n[Result] Overall={overall_accuracy:.3f} ({Num_Success}/{Num_Total})")
|
||||
print(f"[Result] user_choice={user_choice}")
|
||||
break # 完成一个完整流程后退出
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n[Headless] 用户中断")
|
||||
finally:
|
||||
client.send_data('predict', 2) # 关闭系统
|
||||
client.send_data('saveData', 0)
|
||||
server.stop()
|
||||
print("[Headless] 已发送关闭指令,退出。")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_headless()
|
||||
@@ -1,301 +0,0 @@
|
||||
"""
|
||||
ssmvep_headless.py
|
||||
无界面版 SSMVEP 范式通讯流程模拟脚本。
|
||||
复现 ssmvep_main.py 的完整指令序列(train 0/1/2, rest, predict, saveData),
|
||||
但不依赖 psychopy 也不打开任何窗口/音频,用 time.sleep 替代帧循环等待。
|
||||
|
||||
启动顺序:
|
||||
1. runDecoder.py
|
||||
2. datamock.py
|
||||
3. ssmvep_headless.py
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import threading
|
||||
import zmq
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from PubLibrary.InifileHelper import IniRead
|
||||
|
||||
personname = 'demo'
|
||||
session = '01'
|
||||
|
||||
DATAMOCK_LABEL_ADDR = 'tcp://127.0.0.1:8101' # datamock 标签命令地址
|
||||
|
||||
|
||||
# ========== ZMQ 结果接收服务 ==========
|
||||
class ZmqResultServer(threading.Thread):
|
||||
def __init__(self, port=8088):
|
||||
threading.Thread.__init__(self)
|
||||
self.port = port
|
||||
self.running = True
|
||||
self.energy = 0
|
||||
self.paradigm = 0 # 0=个体校准, 1=康复训练, 2=等待模型训练
|
||||
self.ChoosenNum = -1
|
||||
self.context = zmq.Context()
|
||||
self.socket = self.context.socket(zmq.ROUTER)
|
||||
self.socket.bind(f"tcp://0.0.0.0:{self.port}")
|
||||
self.daemon = True
|
||||
self.trial_idx = 0
|
||||
|
||||
def run(self):
|
||||
print(f"[Server] UpperHost_Server listening on {self.port}")
|
||||
while self.running:
|
||||
try:
|
||||
frames = self.socket.recv_multipart(zmq.NOBLOCK)
|
||||
if len(frames) < 3:
|
||||
continue
|
||||
message = json.loads(frames[2].decode('utf-8'))
|
||||
method = message.get('method')
|
||||
params = message.get('params')
|
||||
if method == 'energy':
|
||||
self.energy = params
|
||||
elif method == 'paradigm':
|
||||
self.paradigm = params
|
||||
print(f"[Server] paradigm -> {params}")
|
||||
elif method == 'result':
|
||||
self.ChoosenNum = params
|
||||
self.trial_idx += 1
|
||||
print(f"[Server] result={self.ChoosenNum} (trial {self.trial_idx})")
|
||||
except zmq.Again:
|
||||
time.sleep(0.005)
|
||||
except Exception as e:
|
||||
print(f"[Server] error: {e}")
|
||||
|
||||
def stop(self):
|
||||
self.running = False
|
||||
self.socket.close()
|
||||
self.context.term()
|
||||
|
||||
|
||||
# ========== ZMQ 命令发送客户端 ==========
|
||||
class ZmqCmdClient:
|
||||
def __init__(self, host, port):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.context = zmq.Context()
|
||||
self.socket = self.context.socket(zmq.DEALER)
|
||||
# PUSH socket 用于向 datamock.py 发送标签命令
|
||||
self._label_sock = self.context.socket(zmq.PUSH)
|
||||
self._label_sock.connect(DATAMOCK_LABEL_ADDR)
|
||||
print(f"[Client] label PUSH connected to {DATAMOCK_LABEL_ADDR}")
|
||||
|
||||
def connect(self):
|
||||
self.socket.connect(f"tcp://{self.host}:{self.port}")
|
||||
print(f"[Client] connected to {self.host}:{self.port}")
|
||||
|
||||
def start_recv_thread(self, result_server):
|
||||
"""启动后台线程,持续接收 decoder 通过 8099 ROUTER 回发的消息,并更新 result_server 的状态"""
|
||||
self._result_server = result_server
|
||||
self._stop_recv = threading.Event()
|
||||
|
||||
def _recv_loop():
|
||||
while not self._stop_recv.is_set():
|
||||
try:
|
||||
frames = self.socket.recv_multipart(zmq.NOBLOCK)
|
||||
# DEALER 收到的格式: [b'', json_bytes]
|
||||
data_bytes = frames[-1]
|
||||
message = json.loads(data_bytes.decode('utf-8'))
|
||||
method = message.get('method')
|
||||
params = message.get('params')
|
||||
ts = datetime.now().strftime('%H:%M:%S.%f')[:-3]
|
||||
print(f"[{ts}] [CmdClient] recv: {method}={params}")
|
||||
if method == 'paradigm':
|
||||
self._result_server.paradigm = params
|
||||
print(f"[{ts}] [CmdClient] paradigm updated -> {params}")
|
||||
elif method == 'result':
|
||||
self._result_server.ChoosenNum = params
|
||||
self._result_server.trial_idx += 1
|
||||
print(f"[{ts}] [CmdClient] result={params} (trial {self._result_server.trial_idx})")
|
||||
elif method == 'energy':
|
||||
self._result_server.energy = params
|
||||
except zmq.Again:
|
||||
time.sleep(0.005)
|
||||
except Exception as e:
|
||||
print(f"[CmdClient recv] error: {e}")
|
||||
time.sleep(0.01)
|
||||
|
||||
self._recv_thread = threading.Thread(target=_recv_loop, daemon=True)
|
||||
self._recv_thread.start()
|
||||
print(f"[Client] 后台接收线程已启动(监听 decoder 8099 回发消息)")
|
||||
|
||||
def stop_recv_thread(self):
|
||||
if hasattr(self, '_stop_recv'):
|
||||
self._stop_recv.set()
|
||||
|
||||
def _send_label(self, label_value):
|
||||
"""向 datamock.py 发送标签命令"""
|
||||
try:
|
||||
self._label_sock.send_string(str(label_value), zmq.NOBLOCK)
|
||||
except Exception as e:
|
||||
print(f"[Client] label send error: {e}")
|
||||
|
||||
def send_data(self, method, params):
|
||||
msg = {'method': method, 'params': params}
|
||||
try:
|
||||
self.socket.send_multipart([b'', json.dumps(msg).encode('utf-8')])
|
||||
ts = datetime.now().strftime('%H:%M:%S.%f')[:-3]
|
||||
print(f"[{ts}] send_data: {method}={params}")
|
||||
# 根据 train/predict 命令向 datamock 发送标签
|
||||
if method == 'train':
|
||||
if params == 0:
|
||||
self._send_label(1)
|
||||
print(f"[Label] train 0 -> datamock label=1")
|
||||
elif params == 1:
|
||||
self._send_label(2)
|
||||
print(f"[Label] train 1 -> datamock label=2")
|
||||
elif method == 'predict':
|
||||
self._send_label(99)
|
||||
print(f"[Label] predict -> datamock label=99")
|
||||
except Exception as e:
|
||||
print(f"[Client] send error: {e}")
|
||||
|
||||
|
||||
# ========== 主流程 ==========
|
||||
def run_headless():
|
||||
server = ZmqResultServer(port=8088)
|
||||
server.start()
|
||||
|
||||
_dh = str(IniRead('system', 'Decoder_Host'))
|
||||
_dp = int(IniRead('system', 'Decoder_Port'))
|
||||
client = ZmqCmdClient(_dh, _dp)
|
||||
client.connect()
|
||||
client.start_recv_thread(server) # 启动后台接收线程,监听 decoder 8099 回发的 paradigm/result 消息
|
||||
|
||||
time.sleep(1) # 等待连接建立
|
||||
client.send_data('decoderClass', 'ssmvep')
|
||||
|
||||
train_time = 2.5 # 每轮训练刺激时长 (s)
|
||||
test_time = 2.5 # 每轮测试刺激时长 (s)
|
||||
right_rehabilitation = float(IniRead('system', 'Right_rehabilitation'))
|
||||
fault_rehabilitation = float(IniRead('system', 'Fault_rehabilitation'))
|
||||
rest_time = float(IniRead('system', 'Rest_time'))
|
||||
|
||||
num_blocks = int(IniRead('system', 'Num_blocks'))
|
||||
num_trials = int(IniRead('system', 'Num_trials'))
|
||||
|
||||
position = [0, 1]
|
||||
truePos_seq = position * int(num_trials / len(position))
|
||||
truePos_seq = np.random.permutation(truePos_seq).tolist()
|
||||
user_choice = []
|
||||
|
||||
os.makedirs('EEGFiles', exist_ok=True)
|
||||
seq_file_path = f'EEGFiles/pos_seq_{personname}{session}_{datetime.now().strftime("%Y-%m-%d-%H-%M-%S")}.json'
|
||||
seq_info = {
|
||||
'position': position,
|
||||
'sequence': truePos_seq,
|
||||
'start_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||
}
|
||||
with open(seq_file_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(seq_info, f, ensure_ascii=False, indent=2)
|
||||
|
||||
trained = 0
|
||||
Num_Total = 0
|
||||
Num_Success = 0
|
||||
|
||||
print("=" * 50)
|
||||
print("[Headless] 开始运行 SSMVEP 通讯流程(无界面)")
|
||||
print(f" num_blocks={num_blocks}, num_trials={num_trials}")
|
||||
print(f" train_time={train_time}s, test_time={test_time}s")
|
||||
print("=" * 50)
|
||||
|
||||
try:
|
||||
while True:
|
||||
# -------- 个体校准阶段 --------
|
||||
print("\n[Phase] 个体校准阶段 (paradigm=0)")
|
||||
client.send_data('rest', 0)
|
||||
time.sleep(1)
|
||||
|
||||
# epoch完成需要的额外等待时间:train_latency=120包×20ms=2.4s
|
||||
# 在train_time后需再等epoch_wait秒,decoder才能完成epoch采集并取出数据
|
||||
epoch_wait = 2.4 # 秒,与train_latency对应
|
||||
|
||||
while server.paradigm == 0:
|
||||
# 左腿刺激
|
||||
print(f"\n[Train] 左腿刺激 (train 0) trained={trained}")
|
||||
client.send_data('train', 0)
|
||||
time.sleep(train_time + epoch_wait) # 等待刺激时间+epoch完成时间
|
||||
|
||||
trained += 1
|
||||
client.send_data('rest', 0)
|
||||
time.sleep(max(0, abs(fault_rehabilitation - train_time) - epoch_wait))
|
||||
|
||||
# 右腿刺激
|
||||
print(f"\n[Train] 右腿刺激 (train 1) trained={trained}")
|
||||
client.send_data('train', 1)
|
||||
time.sleep(train_time + epoch_wait) # 等待刺激时间+epoch完成时间
|
||||
|
||||
trained += 1
|
||||
client.send_data('rest', 0)
|
||||
time.sleep(max(0, fault_rehabilitation - epoch_wait))
|
||||
|
||||
# 个体校准阶段结束
|
||||
print("\n[Phase] 个体校准结束,等待 paradigm=1 ...")
|
||||
trained = 0
|
||||
time.sleep(1)
|
||||
|
||||
# -------- 康复训练阶段 --------
|
||||
while server.paradigm == 1:
|
||||
print("\n[Phase] 康复训练阶段 (paradigm=1)")
|
||||
for block_idx in range(num_blocks):
|
||||
print(f"\n [Block {block_idx+1}/{num_blocks}]")
|
||||
time.sleep(10) # 每轮开始前等待
|
||||
|
||||
for trial_idx in range(num_trials):
|
||||
true_position = truePos_seq[trial_idx]
|
||||
print(f" [Trial {trial_idx+1}/{num_trials}] true_pos={true_position}")
|
||||
|
||||
time.sleep(0.5) # 提示 + 叮声
|
||||
server.ChoosenNum = -1
|
||||
|
||||
# 开始测试
|
||||
# predict epoch latency = 115包×20ms = 2.3s,需额外等待epoch完成
|
||||
predict_epoch_wait = 2.3 # 秒,与predict latency=115包对应
|
||||
client.send_data('predict', 1)
|
||||
t_start = time.perf_counter()
|
||||
while time.perf_counter() - t_start < test_time + predict_epoch_wait:
|
||||
if server.ChoosenNum >= 0:
|
||||
Num_Total += 1
|
||||
user_choice.append(server.ChoosenNum)
|
||||
if server.ChoosenNum in [0, 1]:
|
||||
Num_Success += 1
|
||||
rest_time = right_rehabilitation
|
||||
break
|
||||
time.sleep(0.02)
|
||||
|
||||
trained += 1
|
||||
client.send_data('rest', 0)
|
||||
time.sleep(0.5)
|
||||
time.sleep(rest_time)
|
||||
server.ChoosenNum = -1
|
||||
|
||||
# 训练结束
|
||||
print("\n[Phase] 康复训练结束")
|
||||
break # 退出康复训练循环
|
||||
|
||||
# 统计结果
|
||||
overall_accuracy = Num_Success / Num_Total if Num_Total > 0 else 0
|
||||
expected_seq = truePos_seq * num_blocks
|
||||
min_len = min(len(user_choice), len(expected_seq))
|
||||
same_count = sum(1 for a, b in zip(user_choice[:min_len], expected_seq[:min_len]) if a == b)
|
||||
true_accuracy = same_count / min_len if min_len > 0 else 0
|
||||
print(f"\n[Result] Overall={overall_accuracy:.3f} ({Num_Success}/{Num_Total})")
|
||||
print(f"[Result] TrueAcc={true_accuracy:.3f} ({same_count}/{min_len})")
|
||||
break # 完成一个完整流程后退出
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n[Headless] 用户中断")
|
||||
finally:
|
||||
client.send_data('predict', 2) # 关闭系统
|
||||
client.send_data('saveData', 0)
|
||||
server.stop()
|
||||
print("[Headless] 已发送关闭指令,退出。")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_headless()
|
||||
@@ -1,364 +0,0 @@
|
||||
import time
|
||||
|
||||
from psychopy import visual, core, logging # import some libraries from PsychoPy
|
||||
import random
|
||||
from datetime import datetime
|
||||
|
||||
# LAB STREAMING LAYER1
|
||||
from pylsl import StreamInfo, StreamOutlet
|
||||
from psychopy import event
|
||||
import numpy as np
|
||||
from DecoderDW.Server import TCPServer
|
||||
from DecoderDW.Client import TCPClient
|
||||
# import subprocess
|
||||
|
||||
# ----------------------
|
||||
# constants
|
||||
# size of the window
|
||||
WINWIDTH = 1920
|
||||
WINHEIGHT = 1080
|
||||
REFRESH_RATE = 144
|
||||
|
||||
|
||||
|
||||
def get_keypress():
|
||||
keys = event.getKeys()
|
||||
if keys:
|
||||
return keys[0]
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def shutdown(win,client):
|
||||
client.send_data('saveData', 0)
|
||||
client.send_data('predict',2)
|
||||
win.close()
|
||||
core.quit()
|
||||
|
||||
|
||||
# end of configuration
|
||||
# ----------------------
|
||||
|
||||
def generate_square_wave(frequency, sampling_rate=REFRESH_RATE, duration=5):
|
||||
"""
|
||||
生成方波序列
|
||||
|
||||
参数:
|
||||
frequency (float): 频率(Hz)
|
||||
sampling_rate (int): 采样率(Hz),应与屏幕刷新率一致
|
||||
duration (float): 时长(秒)
|
||||
|
||||
返回:
|
||||
square_wave (list): 方波序列
|
||||
"""
|
||||
# 计算总点数
|
||||
n_points = int(duration * sampling_rate)
|
||||
|
||||
# 生成时间序列
|
||||
time = np.linspace(0, duration, n_points, endpoint=False)
|
||||
|
||||
# 生成正弦波数据
|
||||
sin_wave = np.sin(2 * np.pi * frequency * time)
|
||||
# 生成方波数据
|
||||
square_wave = np.where(sin_wave >= 0, 1, 0)
|
||||
|
||||
return square_wave.tolist()
|
||||
|
||||
|
||||
# 启动一个进程,不等待其完成
|
||||
import os
|
||||
if __name__ == "__main__":
|
||||
# ----------------------------------------------------------------------------------
|
||||
# main window settings
|
||||
main_win = visual.Window(size=(WINWIDTH, WINHEIGHT), units='height', screen=0, fullscr=False,
|
||||
gammaErrorPolicy='warn', color=(0.7, 0.7, 0.7))
|
||||
print('starting 1')
|
||||
# Set up LabStreamingLayer stream.
|
||||
info = StreamInfo(name='psychopy_stimuli', type='Markers', channel_count=1, channel_format='string',
|
||||
source_id='psychopy_stimuli_001')
|
||||
outlet = StreamOutlet(info) # Broadcast the stream.
|
||||
|
||||
imageStim1 = visual.ImageStim(main_win, size=(300, 300), pos=(-600, 300), units='pix', image='UI/figures/xy.jpg')
|
||||
txtStim1 = visual.TextStim(win=main_win, text='△', font='SimHei', height=80, color='black', units='pix', bold=True,
|
||||
italic=False, pos=(-600, 30))
|
||||
|
||||
imageStim2 = visual.ImageStim(main_win, size=(300, 300), pos=(0, 300), units='pix', image='UI/figures/xy.jpg')
|
||||
txtStim2 = visual.TextStim(win=main_win, text='△', font='SimHei', height=80, color='black', units='pix', bold=True,
|
||||
italic=False, pos=(0, 30))
|
||||
|
||||
imageStim3 = visual.ImageStim(main_win, size=(300, 300), pos=(600, 300), units='pix', image='UI/figures/xy.jpg')
|
||||
txtStim3 = visual.TextStim(win=main_win, text='△', font='SimHei', height=80, color='black', units='pix', bold=True,
|
||||
italic=False, pos=(600, 30))
|
||||
imageStim4 = visual.ImageStim(main_win, size=(300, 300), pos=(-600, -200), units='pix', image='UI/figures/xy.jpg')
|
||||
txtStim4 = visual.TextStim(win=main_win, text='△', font='SimHei', height=80, color='black', units='pix', bold=True,
|
||||
italic=False, pos=(-600, -470))
|
||||
imageStim5 = visual.ImageStim(main_win, size=(300, 300), pos=(0, -200), units='pix', image='UI/figures/xy.jpg')
|
||||
txtStim5 = visual.TextStim(win=main_win, text='△', font='SimHei', height=80, color='black', units='pix', bold=True,
|
||||
italic=False, pos=(0, -470))
|
||||
imageStim6 = visual.ImageStim(main_win, size=(300, 300), pos=(600, -200), units='pix', image='UI/figures/xy.jpg')
|
||||
txtStim6 = visual.TextStim(win=main_win, text='△', font='SimHei', height=80, color='black', units='pix', bold=True,
|
||||
italic=False, pos=(600, -470))
|
||||
imageStim1red = visual.ImageStim(main_win, size=(300, 300), pos=(-600, 300), units='pix', image='UI/figures/xy_red.jpg')
|
||||
imageStim2red = visual.ImageStim(main_win, size=(300, 300), pos=(0, 300), units='pix', image='UI/figures/xy_red.jpg')
|
||||
imageStim3red = visual.ImageStim(main_win, size=(300, 300), pos=(600, 300), units='pix', image='UI/figures/xy_red.jpg')
|
||||
imageStim4red = visual.ImageStim(main_win, size=(300, 300), pos=(-600, -200), units='pix', image='UI/figures/xy_red.jpg')
|
||||
imageStim5red = visual.ImageStim(main_win, size=(300, 300), pos=(0, -200), units='pix', image='UI/figures/xy_red.jpg')
|
||||
imageStim6red = visual.ImageStim(main_win, size=(300, 300), pos=(600, -200), units='pix', image='UI/figures/xy_red.jpg')
|
||||
|
||||
|
||||
frequencies = [25,26,27,28,29,30] #[9,10,11,12,13,14] #[30,31,32,33,34,35] [25,26,27,28,29,30]
|
||||
# 生成方波数据
|
||||
square_wave_9 = generate_square_wave(frequencies[0], REFRESH_RATE, 5)
|
||||
square_wave_11 = generate_square_wave(frequencies[1], REFRESH_RATE, 5)
|
||||
square_wave_12 = generate_square_wave(frequencies[2], REFRESH_RATE, 5)
|
||||
square_wave_13 = generate_square_wave(frequencies[3], REFRESH_RATE, 5)
|
||||
square_wave_14 = generate_square_wave(frequencies[4], REFRESH_RATE, 5)
|
||||
square_wave_15 = generate_square_wave(frequencies[5], REFRESH_RATE, 5)
|
||||
|
||||
# 创建刺激对象列表,便于管理
|
||||
image_stims = [imageStim1, imageStim2, imageStim3, imageStim4, imageStim5, imageStim6]
|
||||
txt_stims = [txtStim1, txtStim2, txtStim3, txtStim4, txtStim5, txtStim6]
|
||||
square_waves = [square_wave_9, square_wave_11, square_wave_12, square_wave_13, square_wave_14, square_wave_15]
|
||||
|
||||
time.sleep(2)
|
||||
# grating.color = 'black'
|
||||
server = TCPServer()
|
||||
server.start()
|
||||
client = TCPClient('127.0.0.1', 8099)
|
||||
client.connect()
|
||||
print('Connected decoder_main')
|
||||
# client.send_data('impedance', 1)
|
||||
# time.sleep(20)
|
||||
# client.send_data('impedance', 2)
|
||||
client.send_data('targetFreqs', frequencies) # 使用frequencies变量,确保与刺激频率一致
|
||||
time.sleep(1)
|
||||
# 开启全程数据保存到 EEGFiles
|
||||
client.send_data('saveData',1)
|
||||
# client.send_data('impedance',1)
|
||||
|
||||
|
||||
|
||||
# 实验参数
|
||||
repeats = 3
|
||||
seq_freq = frequencies * repeats
|
||||
seq_freq = np.random.permutation(seq_freq).tolist()
|
||||
num_trials = len(seq_freq) # 总试验次数, 6*6=36
|
||||
trial_count = 0
|
||||
|
||||
# 在线解码精度计算
|
||||
online_results = [] # 存储每个trial的解码结果
|
||||
correct_predictions = 0 # 正确预测计数
|
||||
|
||||
# 保存序列信息
|
||||
seq_info = {
|
||||
'total_trials': num_trials,
|
||||
'frequencies': frequencies,
|
||||
'sequence': seq_freq,
|
||||
'start_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||
}
|
||||
# 保存序列信息到文件
|
||||
import json
|
||||
seq_file_path = f'EEGFiles/sequence_{datetime.now().strftime("%Y-%m-%d-%H-%M-%S")}.json'
|
||||
with open(seq_file_path, 'a', encoding='utf-8') as f:
|
||||
json.dump(seq_info, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
#========================Trials Started======================#
|
||||
while trial_count < num_trials:
|
||||
# 从序列中获取当前试验的目标频率
|
||||
target_freq = seq_freq[trial_count]
|
||||
target_freq_index = frequencies.index(target_freq)
|
||||
print(f'Trials {trial_count + 1}/{num_trials} - Target Frequency: {target_freq}Hz (Label: {target_freq_index + 1})')
|
||||
|
||||
# Stage 1: Cue Stage
|
||||
# print('Cue Stage: The target frequency is in Red')
|
||||
client.send_data('setLabelAndTrialInfo', {
|
||||
'label': 0,
|
||||
'trial_info': {
|
||||
'trial': trial_count + 1,
|
||||
'phase': 'cue',
|
||||
'target_freq': target_freq
|
||||
}
|
||||
})
|
||||
|
||||
for frameN in range(int(1 * REFRESH_RATE)): # 1秒提示
|
||||
key_press = get_keypress()
|
||||
if key_press in ['q']:
|
||||
shutdown(main_win, client)
|
||||
|
||||
# 显示所有刺激,目标刺激为红色
|
||||
for i, stim in enumerate(image_stims):
|
||||
if i == target_freq_index:
|
||||
# 目标刺激显示红色
|
||||
if i == 0:
|
||||
imageStim1red.draw()
|
||||
elif i == 1:
|
||||
imageStim2red.draw()
|
||||
elif i == 2:
|
||||
imageStim3red.draw()
|
||||
elif i == 3:
|
||||
imageStim4red.draw()
|
||||
elif i == 4:
|
||||
imageStim5red.draw()
|
||||
elif i == 5:
|
||||
imageStim6red.draw()
|
||||
else:
|
||||
# 其他刺激显示正常颜色
|
||||
stim.draw()
|
||||
|
||||
main_win.flip()
|
||||
|
||||
# Stage 2: Flanker Stimulus
|
||||
# print('Flanker Stage: flank all frequencies')
|
||||
client.send_data('predict', 1)
|
||||
client.send_data('setLabelAndTrialInfo', {
|
||||
'label': target_freq_index + 1, # 设置目标频率标签 这里+1,是因为0代表不记录数据
|
||||
'trial_info': {
|
||||
'trial': trial_count + 1, # trial 从0开始
|
||||
'phase': 'stimulus',
|
||||
'target_freq': target_freq
|
||||
}
|
||||
})
|
||||
outlet.push_sample(['S 1'])
|
||||
|
||||
for frameN in range(6 * REFRESH_RATE): # 6秒刺激
|
||||
key_press = get_keypress()
|
||||
if key_press in ['q']:
|
||||
shutdown(main_win, client)
|
||||
|
||||
# 所有频率按照方波闪烁
|
||||
if square_wave_9[frameN % len(square_wave_9)] == 1:
|
||||
imageStim1.draw()
|
||||
if square_wave_11[frameN % len(square_wave_11)] == 1:
|
||||
imageStim2.draw()
|
||||
if square_wave_12[frameN % len(square_wave_12)] == 1:
|
||||
imageStim3.draw()
|
||||
if square_wave_13[frameN % len(square_wave_13)] == 1:
|
||||
imageStim4.draw()
|
||||
if square_wave_14[frameN % len(square_wave_14)] == 1:
|
||||
imageStim5.draw()
|
||||
if square_wave_15[frameN % len(square_wave_15)] == 1:
|
||||
imageStim6.draw()
|
||||
|
||||
main_win.flip()
|
||||
if server.ChoosenNum != -1:
|
||||
break
|
||||
|
||||
# 记录在线解码结果
|
||||
predicted_freq_index = server.ChoosenNum # 解码结果
|
||||
predicted_freq = frequencies[predicted_freq_index] if predicted_freq_index != -1 else -1
|
||||
|
||||
# 判断解码是否正确
|
||||
is_correct = (predicted_freq_index == target_freq_index) if predicted_freq_index != -1 else False
|
||||
if is_correct:
|
||||
correct_predictions += 1
|
||||
|
||||
# 记录trial结果
|
||||
trial_result = {
|
||||
'trial': trial_count + 1,
|
||||
'target_freq': target_freq,
|
||||
'target_freq_index': target_freq_index,
|
||||
'predicted_freq': predicted_freq,
|
||||
'predicted_freq_index': predicted_freq_index,
|
||||
'is_correct': is_correct,
|
||||
'status': 'Success' if predicted_freq_index != -1 else 'Failed'
|
||||
}
|
||||
online_results.append(trial_result)
|
||||
|
||||
# 打印当前trial结果
|
||||
status_symbol = "✓" if is_correct else "✗"
|
||||
if predicted_freq_index == -1:
|
||||
print(f'Trial {trial_count + 1}: 目标{target_freq}Hz -> 解码失败 - {status_symbol}')
|
||||
else:
|
||||
print(f'Trial {trial_count + 1}: 目标{target_freq}Hz -> 预测{predicted_freq}Hz - {status_symbol}')
|
||||
|
||||
|
||||
# Stage 3: Decoding Feedback
|
||||
outlet.push_sample(['S 2'])
|
||||
client.send_data('setLabelAndTrialInfo', {
|
||||
'label': 0, # 反馈阶段标签为0
|
||||
'trial_info': {
|
||||
'trial': trial_count + 1,
|
||||
'phase': 'feedback',
|
||||
'target_freq': target_freq
|
||||
}
|
||||
})
|
||||
# print('反馈阶段: 显示解码结果')
|
||||
|
||||
for frameN in range(1 * REFRESH_RATE): # 1秒反馈
|
||||
key_press = get_keypress()
|
||||
if key_press in ['q']:
|
||||
shutdown(main_win, client)
|
||||
|
||||
# 显示所有刺激但不闪烁
|
||||
for stim in image_stims:
|
||||
stim.draw()
|
||||
|
||||
# 显示解码结果
|
||||
if server.ChoosenNum == 0:
|
||||
txtStim1.draw()
|
||||
elif server.ChoosenNum == 1:
|
||||
txtStim2.draw()
|
||||
elif server.ChoosenNum == 2:
|
||||
txtStim3.draw()
|
||||
elif server.ChoosenNum == 3:
|
||||
txtStim4.draw()
|
||||
elif server.ChoosenNum == 4:
|
||||
txtStim5.draw()
|
||||
elif server.ChoosenNum == 5:
|
||||
txtStim6.draw()
|
||||
|
||||
main_win.flip()
|
||||
|
||||
server.ChoosenNum = -1
|
||||
trial_count += 1
|
||||
|
||||
# 计算总体在线解码精度
|
||||
total_trials = len(online_results)
|
||||
successful_trials = len([r for r in online_results if r['status'] == 'Success'])
|
||||
failed_trials = len([r for r in online_results if r['status'] == 'Failed'])
|
||||
overall_accuracy = correct_predictions / total_trials if total_trials > 0 else 0
|
||||
|
||||
# Print Accuracy
|
||||
print(f"Total Accuracy: {overall_accuracy:.3f} ({correct_predictions}/{total_trials})")
|
||||
|
||||
# 按频率分析准确率
|
||||
print(f"\n=== 按频率分析准确率 ===")
|
||||
freq_accuracy = {}
|
||||
for result in online_results:
|
||||
freq = result['target_freq']
|
||||
if freq not in freq_accuracy:
|
||||
freq_accuracy[freq] = {'correct': 0, 'total': 0, 'failed': 0}
|
||||
|
||||
freq_accuracy[freq]['total'] += 1
|
||||
if result['status'] == 'Failed':
|
||||
freq_accuracy[freq]['failed'] += 1
|
||||
elif result['is_correct']:
|
||||
freq_accuracy[freq]['correct'] += 1
|
||||
|
||||
print(f"{'频率':<8} {'准确率':<8} {'正确/总数':<10} {'失败数':<8}")
|
||||
print("-" * 40)
|
||||
for freq in sorted(freq_accuracy.keys()):
|
||||
stats = freq_accuracy[freq]
|
||||
accuracy = stats['correct'] / stats['total'] if stats['total'] > 0 else 0
|
||||
print(f"{freq}Hz{'':<4} {accuracy:.3f}{'':<4} {stats['correct']}/{stats['total']}{'':<6} {stats['failed']}")
|
||||
|
||||
# 保存在线解码结果到文件
|
||||
online_results_file = f'EEGFiles/online_results_{datetime.now().strftime("%Y-%m-%d-%H-%M-%S")}.json'
|
||||
online_summary = {
|
||||
'total_trials': total_trials,
|
||||
'successful_trials': successful_trials,
|
||||
'failed_trials': failed_trials,
|
||||
'correct_predictions': correct_predictions,
|
||||
'overall_accuracy': overall_accuracy,
|
||||
# 'freq_accuracy': freq_accuracy,
|
||||
'trial_results': online_results,
|
||||
# 'end_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||
}
|
||||
|
||||
with open(online_results_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(online_summary, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
client.send_data('predict',2) # 关闭系统
|
||||
main_win.close()
|
||||
@@ -1,304 +0,0 @@
|
||||
"""
|
||||
datamock 验证脚本(模拟算法端)
|
||||
作为 ZMQ ROUTER 监听 8100 端口,等待 datamock.py 连接并验证数据流
|
||||
|
||||
运行顺序:
|
||||
第一步: python verify_datamock.py (先启动,监听 8100)
|
||||
第二步: python datamock.py (后启动,连接 8100)
|
||||
"""
|
||||
import zmq
|
||||
import numpy as np
|
||||
import time
|
||||
import sys
|
||||
import matplotlib
|
||||
matplotlib.use('TkAgg')
|
||||
|
||||
# 在导入 pyplot 之前确保 Tkinter 正确初始化
|
||||
try:
|
||||
import tkinter as tk
|
||||
root = tk.Tk()
|
||||
root.withdraw() # 隐藏主窗口,我们只需要它的事件循环
|
||||
except Exception as e:
|
||||
print(f"[WARN] Tkinter 初始化警告: {e}")
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from datetime import datetime
|
||||
|
||||
# ===== 可视化参数 =====
|
||||
PLOT_WINDOW_SEC = 2.0 # 滑动窗口时长(秒)
|
||||
PLOT_CHANNELS = [0, 1, 2, 3] # 要显示的 EEG 通道索引
|
||||
|
||||
SERVER_ADDR = 'tcp://127.0.0.1:8100'
|
||||
FS = 250
|
||||
N_SAMPLES_PER_PKT = 5
|
||||
N_CHAN = 66
|
||||
EEG_FREQ = 10
|
||||
EEG_AMP = 100.0 # EEG 幅值 100μV(峰值)
|
||||
EEG_AMP_MEAN = EEG_AMP * 2 / np.pi # 正弦波 |mean| ≈ 63.7μV
|
||||
EEG_AMP_TOLERANCE = 1.5 # 幅值容差倍数
|
||||
LABEL_INTERVAL = 5
|
||||
FFT_SAMPLES = 250 # 做一次 FFT 需要的采样点数(1s数据)
|
||||
EXPECTED_BYTES = N_SAMPLES_PER_PKT * N_CHAN * 4 # 1320 bytes (5*66*4)
|
||||
|
||||
|
||||
def validate_fft(samples):
|
||||
"""对 Ch0 数据做 FFT,返回峰值频率"""
|
||||
freqs = np.fft.rfftfreq(FFT_SAMPLES, d=1 / FS)
|
||||
fft_mag = np.abs(np.fft.rfft(samples))
|
||||
peak_idx = np.argmax(fft_mag[1:]) + 1 # 跳过 DC
|
||||
return freqs[peak_idx], fft_mag, freqs
|
||||
|
||||
|
||||
def main():
|
||||
ctx = zmq.Context()
|
||||
sock = ctx.socket(zmq.ROUTER)
|
||||
sock.bind(SERVER_ADDR)
|
||||
print(f"[{datetime.now().strftime('%H:%M:%S')}] ZMQ ROUTER 绑定 {SERVER_ADDR},等待 datamock.py 连接...\n")
|
||||
|
||||
# ===== 初始化交互式绘图 =====
|
||||
plt.ion() # 开启交互模式
|
||||
fig = plt.figure(figsize=(14, 10))
|
||||
fig.suptitle('EEG Data Monitor (Real-time)', fontsize=14)
|
||||
|
||||
# 使用 GridSpec 进行布局
|
||||
from matplotlib.gridspec import GridSpec
|
||||
gs = GridSpec(len(PLOT_CHANNELS) + 2, 1, figure=fig, hspace=0.3)
|
||||
axes = []
|
||||
lines_eeg = []
|
||||
for i, ch in enumerate(PLOT_CHANNELS):
|
||||
ax = fig.add_subplot(gs[i])
|
||||
axes.append(ax)
|
||||
ax.set_ylabel(f'Ch{ch} (μV)', fontsize=8)
|
||||
ax.grid(True, alpha=0.3)
|
||||
ax.set_ylim(-150, 150)
|
||||
line, = ax.plot([], [], lw=0.8)
|
||||
lines_eeg.append(line)
|
||||
ax.set_title(f'EEG Channel {ch}', fontsize=9)
|
||||
|
||||
# 标签通道子图 (Ch64 - 标签值)
|
||||
ax_label = fig.add_subplot(gs[len(PLOT_CHANNELS)])
|
||||
axes.append(ax_label)
|
||||
ax_label.set_ylabel('Label Value', fontsize=8)
|
||||
ax_label.grid(True, alpha=0.3)
|
||||
ax_label.set_ylim(-0.5, 2.5)
|
||||
line_label, = ax_label.plot([], [], 'ro-', lw=1.5, markersize=4)
|
||||
line_label_data = line_label
|
||||
ax_label.set_title('Ch64 - Label Value', fontsize=9)
|
||||
|
||||
# Ch65 标签序号子图
|
||||
ax_seq = fig.add_subplot(gs[len(PLOT_CHANNELS) + 1])
|
||||
axes.append(ax_seq)
|
||||
ax_seq.set_ylabel('Label Seq', fontsize=8)
|
||||
ax_seq.set_xlabel('Time (samples)', fontsize=8)
|
||||
ax_seq.grid(True, alpha=0.3)
|
||||
ax_seq.set_ylim(-0.5, 10)
|
||||
line_seq, = ax_seq.plot([], [], 'gs-', lw=1.5, markersize=4)
|
||||
line_seq_data = line_seq
|
||||
ax_seq.set_title('Ch65 - Label Sequence', fontsize=9)
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
# ===== 状态 =====
|
||||
global_idx = 0 # 全局采样点索引
|
||||
label_events = [] # 捕获的标签事件
|
||||
start_time = None
|
||||
fft_done = False
|
||||
fft_buffer = [] # 暂存前 250 点做 FFT
|
||||
ch64_zero_ok = True # 验证 Ch64 非标签采样点均为 0
|
||||
ch65_zero_ok = True # 验证 Ch65 非标签采样点均为 0
|
||||
label_pos_ok_all = True # 验证标签均在包内索引 4
|
||||
|
||||
# ===== 数据缓冲区 =====
|
||||
max_samples = int(FS * PLOT_WINDOW_SEC)
|
||||
eeg_buffer = {ch: np.zeros(max_samples) for ch in PLOT_CHANNELS}
|
||||
label_buffer = np.zeros(max_samples)
|
||||
seq_buffer = np.zeros(max_samples)
|
||||
time_axis = np.arange(max_samples)
|
||||
|
||||
# ZMQ 收发统计
|
||||
recv_count = 0
|
||||
|
||||
try:
|
||||
# 首次 pause 用于显示窗口
|
||||
plt.pause(0.5)
|
||||
print(f"[INFO] 交互窗口已显示,如未看到请检查任务栏")
|
||||
|
||||
while True:
|
||||
# ROUTER recv: prepended 一个 identity 帧
|
||||
# datamock 发送 3帧 [b'datamock', b'', data_bytes]
|
||||
# ROUTER 接收后变成 4帧 [router_identity, b'datamock', b'', data_bytes]
|
||||
frames = sock.recv_multipart()
|
||||
recv_count += 1
|
||||
now = time.time()
|
||||
if start_time is None:
|
||||
start_time = now
|
||||
|
||||
# 帧格式: [router_identity, b'datamock', b'', data_bytes]
|
||||
router_id = frames[0] # ROUTER 添加的身份帧
|
||||
identity = frames[1] # 发送端的 identity
|
||||
_empty = frames[2] # 空帧
|
||||
raw_data = frames[3] # 实际数据字节
|
||||
|
||||
# 数据长度校验
|
||||
if len(raw_data) != EXPECTED_BYTES:
|
||||
print(f"[ERROR] 数据长度错误: 期望{EXPECTED_BYTES}字节, 实际{len(raw_data)}字节")
|
||||
continue
|
||||
|
||||
# 解析为 [5, 66] float32 数组
|
||||
packet = np.frombuffer(raw_data, dtype=np.float32).reshape(N_SAMPLES_PER_PKT, N_CHAN)
|
||||
|
||||
elapsed = now - start_time
|
||||
|
||||
# ===== 验证 1: 数据形状 =====
|
||||
if recv_count == 1:
|
||||
shape_ok = packet.shape == (N_SAMPLES_PER_PKT, N_CHAN)
|
||||
print(f"[{'✓' if shape_ok else '✗'}] 数据形状: {packet.shape} "
|
||||
f"(期望 [{N_SAMPLES_PER_PKT}, {N_CHAN}])")
|
||||
if not shape_ok:
|
||||
print(f" ✗ 形状不匹配,退出")
|
||||
break
|
||||
|
||||
# ===== 验证 2: EEG 幅值(首包) =====
|
||||
if recv_count == 1:
|
||||
eeg = packet[:, :64]
|
||||
amp_mean = np.mean(np.abs(eeg))
|
||||
amp_ok = amp_mean <= EEG_AMP_MEAN * EEG_AMP_TOLERANCE
|
||||
print(f"[{'✓' if amp_ok else '✗'}] EEG 幅值: 均值={amp_mean:.2f}μV "
|
||||
f"(期望 ~{EEG_AMP_MEAN:.2f}μV,峰值 ~{EEG_AMP:.2f}μV)")
|
||||
if not amp_ok:
|
||||
print(f" ✗ 幅值超出容差范围")
|
||||
|
||||
# ===== 验证 3: EEG 频率(首秒数据收集满后做 FFT) =====
|
||||
fft_buffer.append(packet[:, 0].copy()) # 收集 Ch0
|
||||
|
||||
if not fft_done and len(fft_buffer) * N_SAMPLES_PER_PKT >= FFT_SAMPLES:
|
||||
# 凑够 250 点,做 FFT
|
||||
all_ch0 = np.concatenate(fft_buffer)[:FFT_SAMPLES]
|
||||
peak_freq, fft_mag, freqs = validate_fft(all_ch0)
|
||||
freq_ok = abs(peak_freq - EEG_FREQ) < 1.0
|
||||
|
||||
print(f"[{'✓' if freq_ok else '✗'}] EEG 频率: 峰值={peak_freq:.1f}Hz "
|
||||
f"(期望 ~{EEG_FREQ}Hz)")
|
||||
print(f" FFT 幅度谱前 5 峰值:")
|
||||
top5 = np.argsort(fft_mag[1:])[-5:][::-1] + 1
|
||||
for rank, idx in enumerate(top5):
|
||||
print(f" {rank+1}. {freqs[idx]:.1f}Hz 幅度={fft_mag[idx]:.1f}")
|
||||
print()
|
||||
fft_done = True
|
||||
|
||||
# ===== 验证 4: 标签通道(Ch64/Ch65) =====
|
||||
ch64 = packet[:, 64]
|
||||
ch65 = packet[:, 65]
|
||||
ch64_nonzero = np.where(ch64 != 0)[0]
|
||||
ch65_nonzero = np.where(ch65 != 0)[0]
|
||||
|
||||
# 检查非标签采样点是否全为 0
|
||||
ch64_zeros = np.all(ch64[:4] == 0)
|
||||
ch65_zeros = np.all(ch65[:4] == 0)
|
||||
ch64_zero_ok = ch64_zero_ok and ch64_zeros
|
||||
ch65_zero_ok = ch65_zero_ok and ch65_zeros
|
||||
|
||||
if len(ch64_nonzero) > 0:
|
||||
pos_in_pkt = int(ch64_nonzero[0])
|
||||
label_val = int(ch64[pos_in_pkt])
|
||||
label_seq = int(ch65[pos_in_pkt])
|
||||
|
||||
pos_ok = (len(ch64_nonzero) == 1 and pos_in_pkt == 4)
|
||||
label_pos_ok_all = label_pos_ok_all and pos_ok
|
||||
|
||||
elapsed_since_start = now - start_time
|
||||
print(f"[✓] 标签触发 @ {elapsed_since_start:.1f}s "
|
||||
f"(global_idx={global_idx} 包{recv_count})")
|
||||
print(f" Ch64 标签值: {label_val} Ch65 序号: {label_seq}")
|
||||
print(f" 包内位置: 采样点 {pos_in_pkt}/4 "
|
||||
f"({'✓' if pos_ok else '✗ 期望 4'}) "
|
||||
f"其余采样点 Ch64=0: {'✓' if ch64_zeros else '✗'} "
|
||||
f"Ch65=0: {'✓' if ch65_zeros else '✗'}")
|
||||
print()
|
||||
|
||||
label_events.append({
|
||||
'time': elapsed_since_start,
|
||||
'label': label_val,
|
||||
'seq': label_seq
|
||||
})
|
||||
|
||||
global_idx += N_SAMPLES_PER_PKT
|
||||
|
||||
# ===== 更新绘图缓冲区 =====
|
||||
for ch_idx, ch in enumerate(PLOT_CHANNELS):
|
||||
eeg_buffer[ch] = np.roll(eeg_buffer[ch], -N_SAMPLES_PER_PKT)
|
||||
eeg_buffer[ch][-N_SAMPLES_PER_PKT:] = packet[:, ch]
|
||||
|
||||
label_buffer = np.roll(label_buffer, -N_SAMPLES_PER_PKT)
|
||||
label_buffer[-N_SAMPLES_PER_PKT:] = packet[:, 64]
|
||||
|
||||
seq_buffer = np.roll(seq_buffer, -N_SAMPLES_PER_PKT)
|
||||
seq_buffer[-N_SAMPLES_PER_PKT:] = packet[:, 65]
|
||||
|
||||
# ===== 实时更新绘图 =====
|
||||
for i, ch in enumerate(PLOT_CHANNELS):
|
||||
lines_eeg[i].set_data(time_axis, eeg_buffer[ch]) # 数据已是 μV 单位
|
||||
line_label_data.set_data(time_axis, label_buffer)
|
||||
line_seq_data.set_data(time_axis, seq_buffer)
|
||||
|
||||
# 设置 x 轴范围
|
||||
for ax in axes:
|
||||
ax.set_xlim(0, max_samples)
|
||||
|
||||
# 刷新图形(交互模式)
|
||||
fig.canvas.draw_idle()
|
||||
plt.pause(0.001)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n" + "=" * 55)
|
||||
print(" 验证结果汇总")
|
||||
print("=" * 55)
|
||||
print(f" 运行时长: {time.time() - start_time:.1f}s")
|
||||
print(f" 收到包数: {recv_count}")
|
||||
print(f" FFT 验证: {'✓ 已完成' if fft_done else '✗ 未完成(时长不足1s)'}")
|
||||
print(f" 非标签采样点 Ch64=0: {'✓' if ch64_zero_ok else '✗'}")
|
||||
print(f" 非标签采样点 Ch65=0: {'✓' if ch65_zero_ok else '✗'}")
|
||||
print(f" 标签均在包内位置4: {'✓' if label_pos_ok_all else '✗'}")
|
||||
|
||||
if label_events:
|
||||
print(f"\n 共捕获 {len(label_events)} 次标签事件:")
|
||||
for i, ev in enumerate(label_events):
|
||||
print(f" {i+1}. t={ev['time']:.1f}s label={ev['label']} 序号={ev['seq']}")
|
||||
|
||||
# 标签间隔
|
||||
print(f"\n 标签间隔验证 (期望 ~{LABEL_INTERVAL}s):")
|
||||
for i in range(1, len(label_events)):
|
||||
dt = label_events[i]['time'] - label_events[i-1]['time']
|
||||
ok = abs(dt - LABEL_INTERVAL) < 0.1
|
||||
print(f" {i}->{i+1}: {dt:.2f}s {'✓' if ok else '✗'}")
|
||||
|
||||
# 标签交替
|
||||
labels = [e['label'] for e in label_events]
|
||||
alt_ok = all(labels[i] != labels[i+1] for i in range(len(labels) - 1))
|
||||
print(f"\n 标签交替: {labels} {'✓ 交替正确' if alt_ok else '✗ 交替错误'}")
|
||||
|
||||
# 序号
|
||||
label1_seqs = [e['seq'] for e in label_events if e['label'] == 1]
|
||||
label2_seqs = [e['seq'] for e in label_events if e['label'] == 2]
|
||||
s1_ok = label1_seqs == list(range(1, len(label1_seqs) + 1))
|
||||
s2_ok = label2_seqs == list(range(1, len(label2_seqs) + 1))
|
||||
print(f" label=1 序号: {label1_seqs} {'✓' if s1_ok else '✗'}")
|
||||
print(f" label=2 序号: {label2_seqs} {'✓' if s2_ok else '✗'}")
|
||||
else:
|
||||
print(f"\n 未捕获标签事件(运行时长不足 {LABEL_INTERVAL}s)")
|
||||
|
||||
print("=" * 55)
|
||||
|
||||
finally:
|
||||
sock.close()
|
||||
ctx.term()
|
||||
plt.ioff()
|
||||
plt.close('all')
|
||||
try:
|
||||
root.destroy()
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Reference in New Issue
Block a user