2026-06-05 09:34:29 +08:00
|
|
|
|
import ast
|
2026-06-06 14:51:00 +08:00
|
|
|
|
import glob
|
2026-06-06 14:40:07 +08:00
|
|
|
|
import os
|
|
|
|
|
|
import sys
|
2026-06-05 09:34:29 +08:00
|
|
|
|
import threading
|
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
|
import multiprocessing as mp
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
import time
|
|
|
|
|
|
import torch
|
|
|
|
|
|
from queue import Empty
|
|
|
|
|
|
from scipy import signal
|
|
|
|
|
|
from torch.autograd import Variable
|
2026-06-06 14:40:07 +08:00
|
|
|
|
# from Device.SunnyLinker import SunnyLinker64
|
2026-06-05 09:34:29 +08:00
|
|
|
|
from SSMVEP.algorithm.tdca import TDCA
|
|
|
|
|
|
from SSMVEP.algorithm.base import generate_cca_references
|
|
|
|
|
|
from concentration.algorithm.calculate_focus import Calculate
|
|
|
|
|
|
from blinkdetection.algorithm.eye_detection import blink_detection
|
|
|
|
|
|
from Zmq.zmqServer import zmqServer
|
|
|
|
|
|
from Zmq.zmqClient import zmqClient
|
|
|
|
|
|
from MI.Algorithm.conformer_2class import onlineTrain
|
|
|
|
|
|
from PubLibrary.InifileHelper import IniRead
|
2026-06-06 14:40:07 +08:00
|
|
|
|
from logs.log import algo_log
|
2026-06-05 09:34:29 +08:00
|
|
|
|
from SSVEP.dwfbcca import FbccaDw
|
2026-06-06 14:40:07 +08:00
|
|
|
|
# from Tools.plot_MI_EEG import plotMain
|
2026-06-05 09:34:29 +08:00
|
|
|
|
from collections import deque
|
2026-06-07 11:05:24 +08:00
|
|
|
|
from Zmq.filterProcess import SlidingFilter
|
2026-06-06 14:40:07 +08:00
|
|
|
|
|
|
|
|
|
|
def get_root_path():
    """Return the program root directory (where the .py or the packaged executable lives).

    Packaged builds are recognised in two ways:

    * ``sys.frozen`` (set by PyInstaller/cx_Freeze-style freezers): return the
      directory of ``sys.executable``.
    * Nuitka, which per its user manual does NOT set ``sys.frozen`` but injects a
      module-level ``__compiled__``: return ``__compiled__.containing_dir`` when
      that attribute exists. Otherwise return the directory of ``sys.argv[0]``,
      which in onefile mode is the original binary rather than the temporary
      extraction directory that ``__file__`` points into.

    During development the directory containing this source file is returned.
    """
    if getattr(sys, 'frozen', False):
        # Frozen by a classic freezer: the executable sits in the root directory.
        return os.path.dirname(sys.executable)
    compiled = globals().get('__compiled__')
    if compiled is not None:
        # Nuitka build (standalone or onefile).
        containing_dir = getattr(compiled, 'containing_dir', None)
        if containing_dir:
            return containing_dir
        return os.path.dirname(os.path.abspath(sys.argv[0]))
    # Development: directory of this .py file.
    return os.path.dirname(os.path.abspath(__file__))
|
|
|
|
|
|
# Sub-directory of get_root_path() that holds the online-trained model files
# (*.pth); Decoder_main.parameter_init builds model paths inside it.
MODEL_FOLDER = 'online_Models'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Decoder_main(threading.Thread):
    """Decoder thread for the online paradigms (SSVEP/PVS, SSMVEP, MI/MA).

    Construction starts a ``zmqServer`` and creates a ``SlidingFilter`` over
    the server's ``filterBuffer``. ``run()`` polls the server's control flags
    and dispatches to the decoder selected by ``zmqServer.decoder_class``.
    The decoders read from ``zmqServer.paradigmBuffer`` and report through
    ``zmqServer.broadcast_message``.
    """

    def __init__(self, device_info=None):
        """Create the thread object and start the ZMQ server (the thread itself is not started).

        :param device_info: device description dict; its 'sample_rate' key is
            read throughout this class. When None, the server's
            ``device_info`` is adopted (see ``_init_threads``).
        """
        threading.Thread.__init__(self)
        self.device_info = device_info
        self.Runing = True  # main-loop flag; cleared by stop()
        self.decoder = None
        self.decoder_class = None  # active decoder type, set by init_Decoder()
        self.decodingSteps = 0  # 0=idle 1=warm-up 2=decoding 3=sending result
        self.zmqServer = None
        self.sliding_filter = None
        self._filter_started = False  # SlidingFilter.start() is called at most once
        self._init_threads()

    def _init_threads(self):
        """Create and start the ZMQ server, then create (but do not start) the sliding filter."""
        # 1. ZMQ server: receives device data and owns the ring buffers.
        self.zmqServer = zmqServer(device_info=self.device_info)
        self.zmqServer.start()
        # The rest of this class indexes self.device_info directly, so fall
        # back to the server's copy when none was supplied.
        if self.device_info is None:
            self.device_info = self.zmqServer.device_info

        # 2. Sliding filter over the server's filter ring buffer. It is
        #    started lazily from run() once enough data has accumulated.
        self.sliding_filter = SlidingFilter(
            ring_buffer=self.zmqServer.filterBuffer,
            n_chan=self.zmqServer.device_info['channel_nums'],
            srate=self.zmqServer.device_info['sample_rate'],
        )
        # Every filtered result is forwarded through the ZMQ server.
        self.sliding_filter.set_result_callback(self.zmqServer.send_filtered_data)

    def is_valid_signal(self, data, threshold=1e5):
        """Return True unless the mean per-channel variance of *data* exceeds *threshold*.

        :param data: array of shape (chans, samples).
        :param threshold: variance above which the segment is rejected.
        """
        energy = np.mean(np.var(data, axis=1))  # mean of the per-channel variances
        return not (energy > threshold)

    def init_Decoder(self, decoder_class):
        """Initialise the decoder for a paradigm.

        :param decoder_class: 'ssvep'/'pvs', 'ssmvep' or 'mi'/'ma'. Any other
            value (e.g. '') is only recorded; run() then just drains data.
            (Concentration and blink decoders are currently disabled.)
        :return: None
        """
        self.decoder_class = decoder_class
        if decoder_class in ('ssvep', 'pvs'):
            self._init_ssvep()
        elif decoder_class == 'ssmvep':
            self._init_ssmvep()
        elif decoder_class in ('mi', 'ma'):
            self._init_mi()

    def _init_ssvep(self):
        """Build the dynamic-window FBCCA decoder for the server's current target frequencies."""
        self.n_chan = 8
        DW_cost_method, self.DW_cost_tv = ast.literal_eval(IniRead('system', 'SSVEP_ThresholdValue'))
        self.ListFreq = self.zmqServer.targetFreqs
        self.num_target = len(self.ListFreq)
        if self.num_target == 0:
            # No targets: drop any decoder built for a previous target set so
            # decoder_SSVEP() (which checks for `dw`) stays idle.
            vars(self).pop('dw', None)
            return
        srate = self.device_info['sample_rate']
        self.dw = FbccaDw(srate, self.num_target, self.n_chan, 5, 5,
                          0.2, [2.0, 0.1], [8, 7], 50, DW_cost_method)
        self.dw.filterFrequenceBank()  # filter-bank design
        self.dw.setNotchFilterPara()
        self.calculateCount = 0
        self.referenceData = self.dw.reference(self.ListFreq, int(50 * 0.2 * srate), 5)
        self.dw.filterInit()
        self.dw.onlineInit()  # reset for online acquisition

    def _init_ssmvep(self):
        """Configure the 2-class TDCA decoder for SSMVEP (8/9 Hz stimuli)."""
        self.zmqServer.interval_init(self.decoder_class)
        self.n_chan = 8
        self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
        # Decode window in seconds, rounded to 6 decimals to remove float noise.
        self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6)
        self.single_train = 10  # training trials required per class
        self.num_target = 2  # number of classes
        self.list_freqs = np.array([8, 9])  # stimulus frequencies (Hz)
        self.list_phase = np.array([0, 0])  # stimulus phases
        self.tdca = TDCA(padding_len=5, n_components=1)
        self.Yf = generate_cca_references(self.list_freqs, srate=self.device_info['sample_rate'],
                                          T=self.sample_length, phases=self.list_phase, n_harmonics=5)
        self.parameter_init(5, 45)

    def _init_mi(self):
        """Configure the 2-class motor-imagery decoder (model trained in a subprocess)."""
        self.zmqServer.interval_init(self.decoder_class)
        self.n_chan = 21
        self.interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
        # Decode window in seconds, rounded to 6 decimals to remove float noise.
        self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6)
        self.single_train = 40  # training trials required per class
        self.num_target = 2  # number of classes
        self.parameter_init(8, 30)

    def parameter_init(self, bandPass_low, bandPass_high):
        """Reset training state and build the preprocessing filters for an epoch-based decoder.

        Expects ``self.interval_epoch`` in seconds and converts it to sample
        offsets. Designs a 50 Hz notch (Q=30) and a 65-tap FIR band-pass
        [bandPass_low, bandPass_high] Hz, deletes old ``*.pth`` files from the
        model folder, picks a fresh timestamped model path and creates the
        queues used to talk to the training subprocess.

        :param bandPass_low: band-pass lower edge in Hz.
        :param bandPass_high: band-pass upper edge in Hz.
        """
        srate = self.device_info['sample_rate']
        nyquist = srate / 2
        self.interval_epoch = [int(i * srate) for i in self.interval_epoch]  # seconds -> samples
        # Training epochs extend 0.1 s past the decode window.
        self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * srate)]
        self.trainData = []  # training epochs
        self.trainLabel = []  # training labels
        self.plotData = []  # raw epochs kept for report/analysis
        self.plotLabel = []  # labels for plotData
        self.currentLabel = -1  # training label currently shown by the stimulus UI
        self.train_started = False  # model training has been launched
        self.load_model = False  # trained model is loaded and ready
        self.b_notch, self.a_notch = signal.iirnotch(50 / nyquist, 30)  # 50 Hz mains notch
        self.b_design = signal.firwin(
            65, [bandPass_low / nyquist, bandPass_high / nyquist], pass_zero=False)  # band-pass
        filePath = os.path.join(get_root_path(), MODEL_FOLDER) + os.sep
        # The folder may not exist yet; the trained model is later loaded from it.
        os.makedirs(filePath, exist_ok=True)
        for old_pth in glob.glob(os.path.join(filePath, '*.pth')):
            try:
                os.remove(old_pth)
            except OSError as e:
                # e.g. still open elsewhere; a stale file must not abort initialisation
                algo_log(f"Failed to remove old model {old_pth}: {e}")
        fileName = 'Model_' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
        self.modelPath = ''.join([filePath, fileName, '.pth'])
        self.mp_data_queue = mp.Queue()  # parent -> training subprocess
        self.mp_result_queue = mp.Queue()  # training subprocess -> parent

    def preprocess(self, signal_data):
        """Remove each row's mean, then apply the 50 Hz notch and the FIR band-pass along the last axis.

        Filtering is causal (``scipy.signal.lfilter``). Requires parameter_init()
        to have designed the filters.
        """
        signal_data = signal_data - np.mean(signal_data, axis=-1, keepdims=True)
        signal_data = signal.lfilter(self.b_notch, self.a_notch, signal_data, axis=-1)  # mains notch
        return signal.lfilter(self.b_design, 1, signal_data, axis=-1)  # band-pass

    def run(self):
        """Main loop: start the filter, apply decoder switches and dispatch to the active decoder."""
        while self.Runing:
            try:
                self._maybe_start_filter()

                if self.zmqServer.decoder_switch or self.zmqServer.changeTarget:
                    print(f"Decoder_class Switch Detected: {self.zmqServer.decoder_class}")
                    self.zmqServer.decoder_switch = False
                    self.zmqServer.changeTarget = False
                    self.reset_state()  # clear the old decoder's state before switching
                    self.init_Decoder(self.zmqServer.decoder_class)

                # 'sync' currently only returns the state machine to 'rest'.
                if self.zmqServer.state_mode == 'sync':
                    self.zmqServer.state_mode = 'rest'

                # Consume the ring buffer in the decoder first.
                if self.decoder_class in ('ssvep', 'pvs'):
                    self.decoder_SSVEP()
                elif self.decoder_class == 'ssmvep':
                    self.decoder_SSMVEP()
                elif self.decoder_class in ('mi', 'ma'):  # init_Decoder sets both up identically
                    self.decoder_MI()
                else:
                    self._drain_idle()
            except Exception as e:
                algo_log(f"Decoder Loop Error: {e}")
                time.sleep(0.1)  # Prevent CPU spin if error is persistent

    def _maybe_start_filter(self):
        """Start the sliding-filter thread (once) when more than 5 s of data is buffered."""
        if self._filter_started:
            return
        if self.zmqServer.filterBuffer.GetDataLenCount() > self.device_info['sample_rate'] * 5:
            # threading.Thread.start() raises RuntimeError on a second call.
            self.sliding_filter.start()
            self._filter_started = True

    def _drain_idle(self, chunk=25):
        """Discard paradigm-buffer data in *chunk*-sample pieces while nothing is being decoded.

        During an impedance check nothing is consumed; the call only sleeps
        briefly so the main loop does not busy-spin.
        """
        if self.zmqServer.open_Impedance:
            time.sleep(0.005)
            return
        if self.zmqServer.paradigmBuffer.GetDataLenCount() < chunk:
            time.sleep(0.005)
            return
        self.zmqServer.paradigmBuffer.getData(chunk)

    def _epoch_ready(self, epoch_end):
        """Return True once the epoch is finished and ``event_inner_idx + epoch_end`` samples are buffered."""
        server = self.zmqServer
        if not server.epoch_finished:
            return False
        return server.paradigmBuffer.GetDataLenCount() >= epoch_end + server.event_inner_idx

    @staticmethod
    def _now_str():
        """Current wall-clock time as HH:MM:SS.mmm for log lines."""
        return datetime.now().strftime('%H:%M:%S.%f')[:-3]

    @staticmethod
    def _torch_device():
        """Inference device: CUDA when available, else CPU."""
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _to_model_tensor(self, array):
        """Convert a numpy array to a float32 tensor on the inference device."""
        return torch.from_numpy(array).float().to(self._torch_device())

    def decoder_SSVEP(self):
        """Dynamic-window FBCCA decoding for SSVEP/PVS over 50-sample chunks.

        ``zmqServer.StartDecode`` starts a trial. The next chunk warms up the
        filters (step 1 -> 2). Later chunks are scored until ``fbccaDWMW``
        returns a target other than -1 on a signal that passes
        is_valid_signal(). That target is broadcast as a 'result' message and
        decoding returns to idle.
        """
        if self.zmqServer.StartDecode:
            self.zmqServer.StartDecode = False
            self.decodingSteps = 1
            self.zmqServer.paradigmBuffer.resetAllPara()
            print('启动预测')
        if self.zmqServer.paradigmBuffer.GetDataLenCount() < 50:
            time.sleep(0.005)
            return
        if self.zmqServer.open_Impedance:  # no decoding during impedance check
            time.sleep(0.005)  # data stays in the buffer; just avoid a busy spin
            return
        data = self.zmqServer.paradigmBuffer.getDataViaSSVEP(50)
        data = data[:self.n_chan, :]
        if not hasattr(self, 'dw'):  # no decoder (no targets configured)
            return
        if self.decodingSteps == 1:  # warm-up
            self.dw.onlineInit()  # reset the decoder at the start of the trial
            self.dw.warmFilter(data)
            self.decodingSteps = 2
            print('预热数据完成。开始预测')
            return
        if self.decodingSteps != 2:  # idle
            return
        choosenNum = self.dw.fbccaDWMW(data, self.referenceData, self.DW_cost_tv, self.calculateCount)
        self.calculateCount += 1
        if choosenNum == -1 or not self.is_valid_signal(data):
            return
        self.decodingSteps = 3
        print('预测结果:' + str(choosenNum) + ',计算次数:' + str(self.calculateCount))
        self.calculateCount = 0
        self.zmqServer.broadcast_message('result', int(choosenNum))  # send the result to the UI
        self.decodingSteps = 0
        print('发送给界面完成。')

    def decoder_SSMVEP(self):
        """SSMVEP: collect labelled epochs, fit TDCA once every class is complete, then classify."""
        # Train once every class has single_train trials.
        if not self.load_model and all(
                self.trainLabel.count(i) >= self.single_train for i in range(len(self.list_freqs))):
            self.trainData = np.array(self.trainData)
            self.trainLabel = np.array(self.trainLabel)
            print(np.shape(self.trainData), (self.trainLabel))
            self.decoder = self.tdca.fit(self.trainData, self.trainLabel, Yf=self.Yf)
            print('模型训练完成', self._now_str())
            self.load_model = True
            self.zmqServer.broadcast_message('paradigm', 1)  # model ready, notify the host

        mode = self.zmqServer.state_mode
        if mode == 'train':
            self._ssmvep_collect_trial()
        elif mode == 'predict':
            self._ssmvep_predict()
        else:  # rest
            self._drain_idle()

    def _ssmvep_collect_trial(self):
        """Training state: cut one preprocessed train_epoch window per finished epoch."""
        if self.zmqServer.StartTrain:
            self.currentLabel = self.zmqServer.currentLabel
            self.zmqServer.StartTrain = False
        if not self._epoch_ready(self.train_epoch[1]):
            time.sleep(0.0001)
            return
        print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount())
        trainTrial = self.zmqServer.paradigmBuffer.get_SSMVEPData()  # take all buffered data
        idx = self.zmqServer.event_inner_idx
        print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, idx])
        trainTrial = self.preprocess(trainTrial[:self.n_chan, :])
        trainTrial = trainTrial[:, idx + self.train_epoch[0]:idx + self.train_epoch[1]]
        print('trial: ', idx, self.train_epoch[0], self.train_epoch[1])
        # trainLabel becomes an ndarray once the model is trained; stop collecting then.
        if (trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0])
                and isinstance(self.trainLabel, list)
                and self.trainLabel.count(self.currentLabel) < self.single_train):
            self.trainData.append(trainTrial)
            self.trainLabel.append(self.currentLabel)

    def _ssmvep_predict(self):
        """Predict state: classify one interval_epoch window per finished epoch with TDCA."""
        if not self.load_model:  # model not trained yet
            time.sleep(0.01)
            return
        if self.zmqServer.StartDecode:
            self.zmqServer.StartDecode = False
            print('启动预测 ', self._now_str())
        if not self._epoch_ready(self.interval_epoch[1]):
            time.sleep(0.0001)
            return
        data = self.zmqServer.paradigmBuffer.get_SSMVEPData()  # take all buffered data
        idx = self.zmqServer.event_inner_idx
        print('取出的: ', data.shape, 'event: ', data[-2, idx])
        data = self.preprocess(data[:self.n_chan, :])
        data = data[:, idx + self.interval_epoch[0]:idx + self.interval_epoch[1]]
        srate = self.device_info['sample_rate']
        # Zero-pad by 0.1 s, presumably to match the 0.1 s-longer training epochs.
        pad_eeg_test = np.zeros((data.shape[0], int((self.sample_length + 0.1) * srate)))
        pad_eeg_test[:, :int(self.sample_length * srate)] = data
        choosenNum, features_2 = self.decoder.predict(pad_eeg_test)
        if isinstance(choosenNum, np.ndarray):
            choosenNum = choosenNum[0]
        rho = sorted(features_2[0])
        print('结果:', choosenNum, 'rho: ', rho, rho[-1] - rho[-2])
        self.zmqServer.broadcast_message('result', int(choosenNum))
        print('发送给界面完成。')

    def decoder_MI(self):
        """Motor imagery: collect epochs, train in a subprocess, load the model, then classify."""
        # Launch training once every class has single_train trials.
        if not self.train_started and all(
                self.trainLabel.count(i) >= self.single_train for i in range(self.num_target)):
            self.zmqServer.broadcast_message('paradigm', 2)  # training set complete, notify the host
            self.train_started = True
            self.trainData = np.array(self.trainData)
            # Labels are shifted by +1 before being handed to onlineTrain.
            self.trainLabel = np.array(self.trainLabel) + 1
            p = mp.Process(target=onlineTrain, args=(self.mp_data_queue, self.mp_result_queue))
            p.start()
            self.mp_data_queue.put({'data': self.trainData, 'label': self.trainLabel,
                                    'modelPath': self.modelPath, 'n_chan': self.n_chan})

        # Poll the training subprocess until the model is loaded.
        # NOTE(review): if training reports failure, train_started stays True
        # and collection never resumes until the decoder is re-initialised.
        if not self.load_model and self.train_started:
            self._mi_try_load_model()

        mode = self.zmqServer.state_mode
        if mode == 'train' and not self.train_started:
            self._mi_collect_trial()
        elif mode == 'predict' and self.load_model:
            self._mi_predict()
        else:  # rest, or training/loading in progress
            self._drain_idle()

    def _mi_try_load_model(self):
        """Check for a training result; on success load, warm up and announce the model."""
        try:
            result = self.mp_result_queue.get_nowait()
            if result['status'] == 'success':
                print("模型训练完成,加载新模型")
                # map_location lets a model saved on another device load here.
                self.model = torch.load(self.modelPath, weights_only=False,
                                        map_location=self._torch_device())
                self.model.eval()
                # Warm-up pass with a random epoch of the same length _mi_predict
                # feeds (presumably to absorb first-call initialisation cost).
                n_samples = self.interval_epoch[1] - self.interval_epoch[0]
                warmup = np.random.uniform(-1, 1, (1, 1, self.n_chan, n_samples))
                with torch.no_grad():
                    self.model(self._to_model_tensor(warmup))
                self.load_model = True
                self.zmqServer.broadcast_message('paradigm', 1)  # model ready, notify the host
            else:
                print("训练失败:", result['msg'])
        except Empty:
            pass  # training still running
        except Exception as e:
            print('模型调用失败: ', e)
            algo_log(f"MI model load failed: {e}")

    def _mi_collect_trial(self):
        """Training state: store one preprocessed interval_epoch window (and its raw copy) per epoch."""
        if self.zmqServer.StartTrain:
            self.currentLabel = self.zmqServer.currentLabel
            self.zmqServer.StartTrain = False
        if not self._epoch_ready(self.interval_epoch[1]):
            time.sleep(0.0001)
            return
        print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount())
        originalTrial = self.zmqServer.paradigmBuffer.get_MIData()  # MI channel data
        idx = self.zmqServer.event_inner_idx
        print('取出的: ', originalTrial.shape, 'event: ', originalTrial[-2, idx])
        lo, hi = idx + self.interval_epoch[0], idx + self.interval_epoch[1]
        trainTrial = self.preprocess(originalTrial[:self.n_chan, :])[:, lo:hi]
        print('trial: ', idx, self.interval_epoch[0], self.interval_epoch[1])
        if (trainTrial.shape[1] == (self.interval_epoch[1] - self.interval_epoch[0])
                and isinstance(self.trainLabel, list)
                and self.trainLabel.count(self.currentLabel) < self.single_train):
            self.trainData.append(trainTrial)
            self.trainLabel.append(self.currentLabel)
            print('训练集:', np.shape(self.trainData))
            self.plotData.append(originalTrial[:self.n_chan, lo:hi])  # raw epoch for the report
            self.plotLabel.append(self.currentLabel)

    def _mi_predict(self):
        """Predict state: classify one interval_epoch window per finished epoch with the loaded model."""
        if self.zmqServer.StartDecode:
            self.zmqServer.StartDecode = False
            print('启动预测 ', self._now_str())
        if not self._epoch_ready(self.interval_epoch[1]):
            time.sleep(0.0001)
            return
        originalData = self.zmqServer.paradigmBuffer.get_MIData()  # take all buffered data
        idx = self.zmqServer.event_inner_idx
        print('取出的: ', originalData.shape, 'event: ', originalData[-2, idx])
        t0 = time.perf_counter()
        lo, hi = idx + self.interval_epoch[0], idx + self.interval_epoch[1]
        data = self.preprocess(originalData[:self.n_chan, :])[:, lo:hi]
        self.plotData.append(originalData[:self.n_chan, lo:hi])
        test_data = self._to_model_tensor(data[np.newaxis, np.newaxis, :, :])
        with torch.no_grad():
            Cls = self.model(test_data)
        y_pred = torch.max(Cls, 1)[1]
        self.plotLabel.append(int(y_pred.item()))
        print('运动意图识别: ', y_pred)
        # NOTE(review): predictions use the 'paradigm' message type, which is
        # also used for the 1/2 control notifications above — confirm protocol.
        self.zmqServer.broadcast_message('paradigm', int(y_pred.item()))
        print(f'发送给界面完成,耗时{time.perf_counter() - t0:.3f}s。')

    def stop(self):
        """Stop the main loop, the ZMQ server and the sliding filter."""
        self.Runing = False  # clear first so run() stops touching the server
        self.zmqServer.stop()
        self.sliding_filter.stop()

    def reset_state(self):
        """Clear decoder state, training buffers and pending multiprocessing messages."""
        self.zmqServer.reset_state()  # device-side buffers

        # Decoding state
        self.decodingSteps = 0
        self.calculateCount = 0

        # Training data
        self.plotData = []
        self.plotLabel = []
        self.trainData = []
        self.trainLabel = []
        self.currentLabel = -1
        self.train_started = False
        self.load_model = False

        # Drop stale queue content so it cannot leak into the next decoder.
        for q in (getattr(self, 'mp_data_queue', None), getattr(self, 'mp_result_queue', None)):
            if q is not None:
                self._drain_mp_queue(q)

    @staticmethod
    def _drain_mp_queue(q):
        """Discard everything currently readable from *q* without blocking."""
        while True:
            try:
                q.get_nowait()
            except Empty:
                break  # empty (or momentarily locked): never spin here
|