Files
bci_algo/Decoder.py

521 lines
27 KiB
Python
Raw Normal View History

2026-06-05 09:34:29 +08:00
import ast
2026-06-06 14:51:00 +08:00
import glob
2026-06-06 14:40:07 +08:00
import os
import sys
2026-06-05 09:34:29 +08:00
import threading
from datetime import datetime
import multiprocessing as mp
import numpy as np
import time
import torch
from queue import Empty
import queue
2026-06-05 09:34:29 +08:00
from scipy import signal
from torch.autograd import Variable
2026-06-06 14:40:07 +08:00
# from Device.SunnyLinker import SunnyLinker64
2026-06-05 09:34:29 +08:00
from SSMVEP.algorithm.tdca import TDCA
from SSMVEP.algorithm.base import generate_cca_references
2026-06-10 16:04:02 +08:00
# from concentration.algorithm.calculate_focus import Calculate
# from blinkdetection.algorithm.eye_detection import blink_detection
2026-06-05 09:34:29 +08:00
from Zmq.zmqServer import zmqServer
from Zmq.zmqClient import zmqClient
from MI.Algorithm.conformer_2class import onlineTrain
from PubLibrary.InifileHelper import IniRead
2026-06-06 14:40:07 +08:00
from logs.log import algo_log
2026-06-05 09:34:29 +08:00
from SSVEP.dwfbcca import FbccaDw
2026-06-06 14:40:07 +08:00
# from Tools.plot_MI_EEG import plotMain
2026-06-05 09:34:29 +08:00
from collections import deque
2026-06-07 11:05:24 +08:00
from Zmq.filterProcess import SlidingFilter
2026-06-06 14:40:07 +08:00
2026-06-09 14:23:25 +08:00
# Debug switch read via IniRead('system', 'save_train_data') with default 0.
# When it equals 1, each completed SSMVEP/MI training set is also dumped to a
# timestamped .npz file in the current working directory.
save_train_data = int(
    IniRead('system', 'save_train_data', 0)
)
2026-06-06 14:40:07 +08:00
def get_root_path():
    """
    Return the program's root directory (works for Nuitka-packaged builds).

    When running as a packaged executable (``sys.frozen`` is truthy) this is the
    directory containing the executable; during development it is the directory
    containing this .py file.
    """
    is_packaged = getattr(sys, 'frozen', False)
    anchor = sys.executable if is_packaged else os.path.abspath(__file__)
    return os.path.dirname(anchor)
# Sub-folder (joined onto get_root_path()) where the online-trained model file (.pth) is stored.
MODEL_FOLDER = 'online_Models'
class Decoder_main(threading.Thread):
    """Online EEG decoding thread.

    Owns a ``zmqServer`` (data and command transport) and a ``SlidingFilter``.
    On every loop iteration it dispatches to the decoder selected by
    ``zmqServer.decoder_class``: SSVEP/PVS (FbccaDw), SSMVEP (TDCA) or MI/MA
    (model trained in a child process by ``onlineTrain``). Results and status
    messages are sent through ``zmqServer.broadcast_message``.
    """

    # Class labels expected in the training set (num_target == 2 for SSMVEP and MI).
    _TRAIN_LABELS = (1, 2)

    def __init__(self, device_info=None):
        threading.Thread.__init__(self)
        self.device_info = device_info
        self.Runing = True  # run() loops while this is True; cleared by stop()
        self.decoder = None
        self.decoder_class = None  # active decoder type, see init_Decoder
        # 0 = not decoding, 1 = warm-up, 2 = decoding, 3 = decoded, send the result
        self.decodingSteps = 0

        self.zmqServer = zmqServer(device_info=self.device_info)
        self.zmqServer.start()  # start the ZMQ receiving thread
        if self.device_info is None:
            # run(), init_Decoder() and parameter_init() index self.device_info directly,
            # so fall back to the server's device info (already required by the
            # SlidingFilter below) instead of failing later on None.
            self.device_info = self.zmqServer.device_info

        self.sliding_filter = SlidingFilter(
            ring_buffer=self.zmqServer.filterBuffer,
            n_chan=self.zmqServer.device_info['channel_nums'],
            srate=self.zmqServer.device_info['sample_rate'],
        )
        # A thread can only be start()ed once; run() uses this flag so it never retries.
        self._filter_started = False
        # Forward each filtered result through the ZMQ server.
        self.sliding_filter.filter_result_callback = self.zmqServer.send_filtered_data
        # Broadcast beta_psd values to the host application
        # (original note: once per second through port 8099).
        self.sliding_filter.set_beta_broadcast_callback(
            lambda v: self.zmqServer.broadcast_message('beta_psd', v))

    def is_valid_signal(self, data, threshold=1e5):
        """Return True when *data* looks like usable signal.

        :param data: array of shape (chans, samples)
        :param threshold: maximum allowed mean per-channel variance
        :return: False when the mean channel variance exceeds *threshold* or is
            not finite (NaN/inf samples), True otherwise
        """
        energy = np.mean(np.var(data, axis=1))  # mean of the per-channel variances
        # A NaN energy compares False with '>' and would otherwise pass as valid.
        return bool(np.isfinite(energy) and energy <= threshold)

    def init_Decoder(self, decoder_class):
        '''
        Initialise the decoder for the requested paradigm.

        :param decoder_class: 'ssvep', 'pvs', 'ssmvep', 'mi', 'ma' or '' (any other
            value leaves no decoder configured; 'concentration' and 'blink' are
            currently disabled - their earlier implementations were commented out
            and can be recovered from version control)
        :return: None
        '''
        self.decoder_class = decoder_class
        if decoder_class in ('ssvep', 'pvs'):
            self.n_chan = 8
            DW_cost_method, self.DW_cost_tv = ast.literal_eval(IniRead('system', 'SSVEP_ThresholdValue'))
            self.ListFreq = self.zmqServer.targetFreqs
            self.num_target = len(self.ListFreq)
            if self.num_target == 0:
                # No targets: drop any FbccaDw left from a previous configuration so
                # decoder_SSVEP (which checks hasattr(self, 'dw')) does not keep
                # decoding against stale target frequencies.
                if hasattr(self, 'dw'):
                    del self.dw
                return
            # Second-generation (dynamic window FBCCA) algorithm object
            self.dw = FbccaDw(self.device_info['sample_rate'], self.num_target, self.n_chan, 5, 5,
                              0.2, [2.0, 0.1], [8, 7], 50, DW_cost_method)
            self.dw.filterFrequenceBank()  # filter bank
            self.dw.setNotchFilterPara()
            self.calculateCount = 0
            self.referenceData = self.dw.reference(
                self.ListFreq, int(50 * 0.2 * self.device_info['sample_rate']), 5)
            self.dw.filterInit()
            self.dw.onlineInit()  # reset at the first second of stimulus flashing (online acquisition)
        elif decoder_class == 'ssmvep':
            self.zmqServer.interval_init(decoder_class)
            self.n_chan = 8
            # Epoch window in seconds relative to the event, e.g. [0.2, 2.2]
            self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
            # Decoding window length in seconds (e.g. 2 s), rounded to 6 decimals
            self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6)
            self.single_train = 10  # training trials per class
            self.num_target = 2  # number of classes
            self.list_freqs = np.array([8, 9])  # stimulus frequencies
            self.list_phase = np.array([0, 0])  # stimulus phases
            self.tdca = TDCA(padding_len=5, n_components=1)
            self.Yf = generate_cca_references(self.list_freqs, srate=self.device_info['sample_rate'],
                                              T=self.sample_length, phases=self.list_phase, n_harmonics=5)
            self.parameter_init(5, 45)
        elif decoder_class in ('mi', 'ma'):
            self.zmqServer.interval_init(decoder_class)
            self.n_chan = 21
            # Epoch window in seconds relative to the event, e.g. [0.5, 4.5]
            self.interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
            # Decoding window length in seconds (e.g. 4 s), rounded to 6 decimals
            self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6)
            self.single_train = 40  # training trials per class
            self.num_target = 2  # number of classes
            self.parameter_init(8, 30)

    @staticmethod
    def _seconds_to_samples(seconds, srate):
        """Convert a time in seconds to the nearest whole number of samples.

        round() instead of int(): int() truncates products that land just below an
        integer (0.57 * 100 == 56.99999999999999 -> 56).
        """
        return int(round(seconds * srate))

    def parameter_init(self, bandPass_low, bandPass_high):
        '''
        Prepare epoching, filters, training state and the model path for SSMVEP/MI.

        Expects self.interval_epoch in seconds and converts it to samples in place.
        :param bandPass_low: lower band-pass edge (Hz)
        :param bandPass_high: upper band-pass edge (Hz)
        '''
        srate = self.device_info['sample_rate']
        nyq = srate / 2
        # Epoch window in samples, e.g. SSMVEP [0.2, 2.2] s at 250 Hz -> [50, 550]
        self.interval_epoch = [self._seconds_to_samples(t, srate) for t in self.interval_epoch]
        # Training epoch: same start, end extended by 0.1 s (e.g. SSMVEP [50, 575])
        self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * srate)]
        self.trainData = []  # training epochs
        self.trainLabel = []  # training labels
        self.plotData = []  # raw epochs kept for the analysis report
        self.plotLabel = []  # labels kept for the analysis report
        self.currentLabel = -1  # training label currently shown by the stimulus UI
        self.train_started = False  # whether model training has been started
        self.load_model = False  # whether a trained model is ready for prediction
        # 50 Hz power-line notch filter, quality factor 30
        self.b_notch, self.a_notch = signal.iirnotch(50 / nyq, 30)
        # 65-tap FIR band-pass filter [bandPass_low, bandPass_high] Hz
        self.b_design = signal.firwin(65, [bandPass_low / nyq, bandPass_high / nyq], pass_zero=False)

        filePath = os.path.join(get_root_path(), MODEL_FOLDER) + os.sep
        # The folder may not exist yet (e.g. first run); create it so the training
        # process presumably has somewhere to write self.modelPath.
        os.makedirs(filePath, exist_ok=True)
        # Delete models left over from previous sessions.
        for old_pth in glob.glob(os.path.join(filePath, '*.pth')):
            os.remove(old_pth)
        fileName = 'Model_' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
        self.modelPath = ''.join([filePath, fileName, '.pth'])
        # Queues used to exchange data/results with the MI training process
        self.mp_data_queue = mp.Queue()
        self.mp_result_queue = mp.Queue()

    def preprocess(self, signal_data):
        """Demean each row, then apply the 50 Hz notch and the band-pass filter along the last axis."""
        row_means = np.mean(signal_data, axis=-1, keepdims=True)
        signal_data = signal_data - row_means  # remove each row's mean
        signal_data = signal.lfilter(self.b_notch, self.a_notch, signal_data, axis=-1)  # power-line notch
        signal_data = signal.lfilter(self.b_design, 1, signal_data, axis=-1)  # band-pass
        return signal_data

    def run(self):
        """Main loop: start the filter thread, apply decoder switches, run one decoder step per iteration."""
        while self.Runing:
            try:
                self._maybe_start_filter()
                if self.zmqServer.decoder_switch or self.zmqServer.changeTarget:
                    self._switch_decoder()
                # 'sync' request: return to 'rest' (the reply to clients is currently disabled)
                if self.zmqServer.state_mode == 'sync':
                    self.zmqServer.state_mode = 'rest'
                if self.zmqServer.open_Impedance:  # no decoding while impedance is measured
                    time.sleep(0.005)
                    continue
                if self.decoder_class in ('ssvep', 'pvs'):
                    self.decoder_SSVEP()
                elif self.decoder_class == 'ssmvep':
                    self.decoder_SSMVEP()
                elif self.decoder_class in ('mi', 'ma'):
                    # 'ma' is initialised exactly like 'mi' in init_Decoder, so decode it the same way
                    self.decoder_MI()
                else:
                    self._drain_paradigm_buffer()
            except Exception as e:
                algo_log(f"Decoder Loop Error: {e}")
                time.sleep(0.1)  # Prevent CPU spin if error is persistent

    def _maybe_start_filter(self):
        """Start the sliding-filter thread (only ever once) when more than 5 s of data is buffered.

        is_alive() alone is not enough: after the filter thread exits it is False
        again, and a second start() would raise RuntimeError.
        """
        if self._filter_started:
            return
        if self.sliding_filter.is_alive():
            self._filter_started = True
            return
        if self.zmqServer.filterBuffer.GetDataLenCount() > self.device_info['sample_rate'] * 5:
            algo_log("启动滤波线程", level="DEBUG")
            self._filter_started = True
            self.sliding_filter.start()

    def _switch_decoder(self):
        """Apply a pending decoder/target change: reset all state, then initialise the new decoder."""
        algo_log(f"Decoder_class Switch Detected: {self.zmqServer.decoder_class}", level="DEBUG")
        self.zmqServer.decoder_switch = False
        self.zmqServer.changeTarget = False
        self.reset_state()  # clean up the old state before switching
        try:
            self.init_Decoder(self.zmqServer.decoder_class)
        except Exception:
            # Do not leave a half-initialised decoder active; run() logs the error.
            self.decoder_class = None
            raise

    def _drain_paradigm_buffer(self, n_samples=25):
        """Idle step: discard one chunk of paradigm data (or sleep briefly if not enough is buffered)."""
        if self.zmqServer.paradigmBuffer.GetDataLenCount() < n_samples:
            time.sleep(0.005)
            return
        self.zmqServer.paradigmBuffer.getData(n_samples)

    def _next_epoch(self, idle_sleep):
        """Return the next epoch payload from zmqServer.epoch_queue, or None (after sleeping) if it is empty."""
        try:
            return self.zmqServer.epoch_queue.get_nowait()
        except queue.Empty:
            time.sleep(idle_sleep)
            return None

    def _train_set_complete(self):
        """True while the training set is still a list holding at least single_train trials of every class.

        The list check matters because trainLabel becomes an ndarray once training
        starts; without it, a failed fit would make .count raise on every loop.
        """
        return isinstance(self.trainLabel, list) and all(
            self.trainLabel.count(label) >= self.single_train for label in self._TRAIN_LABELS)

    def _save_train_set(self):
        """Dump the current training set to '<timestamp>.npz' when save_train_data == 1."""
        if save_train_data == 1:
            now_str = datetime.now().strftime("%Y%m%d_%H%M%S")
            save_path = f"{now_str}.npz"
            np.savez(save_path, array1=self.trainData, array2=self.trainLabel)

    @staticmethod
    def _now_hms_ms():
        """Current local time formatted as HH:MM:SS.mmm for log messages."""
        return datetime.now().strftime('%H:%M:%S.%f')[:-3]

    @staticmethod
    def _to_model_input(epoch):
        """Wrap a 2-D (n_chan, n_samples) epoch as a float32 CUDA tensor of shape (1, 1, n_chan, n_samples)."""
        return torch.from_numpy(epoch[np.newaxis, np.newaxis, :, :]).type(torch.cuda.FloatTensor)

    def decoder_SSVEP(self):
        """One SSVEP/PVS step using the FbccaDw dynamic-window decoder.

        decodingSteps: 1 -> warm up the filters on the first chunk; 2 -> classify each
        50-sample chunk until a valid result is found, then broadcast it and go back to 0.
        """
        if self.zmqServer.StartDecode:
            self.zmqServer.StartDecode = False
            self.decodingSteps = 1
            self.zmqServer.paradigmBuffer.resetAllPara()
            algo_log('启动SSVEP预测', level="DEBUG")
        if self.zmqServer.paradigmBuffer.GetDataLenCount() < 50:
            time.sleep(0.005)
            return
        if self.zmqServer.open_Impedance:  # no decoding while impedance is measured
            return
        data = self.zmqServer.paradigmBuffer.getDataViaSSVEP(50)
        data = data[:self.n_chan, :]
        has_dw = hasattr(self, 'dw')
        if self.decodingSteps == 1 and has_dw:  # warm-up
            self.dw.onlineInit()  # reset at the first second of stimulus flashing (online acquisition)
            self.dw.warmFilter(data)
            self.decodingSteps = 2
            algo_log('SSVEP预热数据完成。开始预测', level="DEBUG")
            return
        if self.decodingSteps == 2 and has_dw:  # decoding
            choosenNum = self.dw.fbccaDWMW(data, self.referenceData, self.DW_cost_tv, self.calculateCount)
            self.calculateCount += 1
            if choosenNum != -1 and self.is_valid_signal(data):
                # Step 3 (send the decoded result) happens immediately, in the same call
                # that produced choosenNum.
                self.decodingSteps = 3
                algo_log('SSVEP预测结果' + str(choosenNum) + ',计算次数:' + str(self.calculateCount), level="DEBUG")
                self.calculateCount = 0
                self.zmqServer.broadcast_message('result', int(choosenNum))
                self.decodingSteps = 0
                algo_log('SSVEP发送给界面完成。', level="DEBUG")

    def decoder_SSMVEP(self):
        """One SSMVEP step: train TDCA once the training set is complete, then collect
        training epochs ('train' mode) or classify epochs ('predict' mode)."""
        # --- model training (once every class has single_train trials) ---
        if not self.load_model and self._train_set_complete():
            self.trainData = np.array(self.trainData)
            self.trainLabel = np.array(self.trainLabel)
            algo_log(f"开始SSMVEP模型训练数据形状{np.shape(self.trainData)},标签形状:{self.trainLabel.shape}", level="DEBUG")
            self._save_train_set()
            self.decoder = self.tdca.fit(self.trainData, self.trainLabel, Yf=self.Yf)
            formatted_time = self._now_hms_ms()
            algo_log(f"SSMVEP模型训练完成时间{formatted_time}", level="DEBUG")
            self.load_model = True
            self.zmqServer.broadcast_message('paradigm', 1)

        mode = self.zmqServer.state_mode
        if mode == 'train':  # collect training epochs
            epoch_payload = self._next_epoch(0.0001)
            if epoch_payload is None:
                return
            trainTrial = epoch_payload['snapshot']
            event_inner_idx = epoch_payload['event_inner_idx']
            self.currentLabel = epoch_payload['currentLabel']
            data_length = epoch_payload['data_length']
            if trainTrial is None or data_length < self.train_epoch[1] + event_inner_idx:
                algo_log(f"SSMVEP epoch数据长度不足: {data_length}, 跳过", level="WARNING")
                return
            algo_log(f"取出的:{trainTrial.shape}event{trainTrial[-2, event_inner_idx]}", level="DEBUG")
            lo = event_inner_idx + self.train_epoch[0]
            hi = event_inner_idx + self.train_epoch[1]
            trainTrial = self.preprocess(trainTrial[:self.n_chan, :])[:, lo:hi]
            if (trainTrial.shape[1] == self.train_epoch[1] - self.train_epoch[0]
                    and isinstance(self.trainLabel, list)
                    and self.trainLabel.count(self.currentLabel) < self.single_train):
                self.trainData.append(trainTrial)
                self.trainLabel.append(self.currentLabel)
                algo_log(f"SSMVEP训练集{np.shape(self.trainData)}", level="DEBUG")
        elif mode == 'predict':  # online prediction
            if not self.load_model:  # model not trained yet
                time.sleep(0.01)
                return
            if self.zmqServer.StartDecode:
                self.zmqServer.StartDecode = False
                formatted_time = self._now_hms_ms()
                algo_log(f"SSMVEP模型启动预测 {formatted_time}", level="DEBUG")
            epoch_payload = self._next_epoch(0.0001)
            if epoch_payload is None:
                return
            data = epoch_payload['snapshot']
            event_inner_idx = epoch_payload['event_inner_idx']
            data_length = epoch_payload['data_length']
            if data is None or data_length < self.interval_epoch[1] + event_inner_idx:
                algo_log(f"SSMVEP predict epoch数据长度不足: {data_length}, 跳过", level="WARNING")
                return
            algo_log(f"取出的:{data.shape}, event: {data[-2, event_inner_idx]}", level="DEBUG")
            lo = event_inner_idx + self.interval_epoch[0]
            hi = event_inner_idx + self.interval_epoch[1]
            data = self.preprocess(data[:self.n_chan, :])[:, lo:hi]
            # Zero-pad the test epoch to exactly the training-epoch width (window + 0.1 s)
            # so the TDCA input matches what it was trained on.
            pad_eeg_test = np.zeros((data.shape[0], self.train_epoch[1] - self.train_epoch[0]))
            pad_eeg_test[:, :data.shape[1]] = data
            choosenNum, features_2 = self.decoder.predict(pad_eeg_test)
            if isinstance(choosenNum, np.ndarray):
                choosenNum = choosenNum[0]
            algo_log(f"结果:{choosenNum}, rho: {sorted(features_2[0])[-1] - sorted(features_2[0])[-2]}", level="DEBUG")
            self.zmqServer.broadcast_message('result', int(choosenNum))
            algo_log("SSMVEP发送给界面完成。", level="DEBUG")
        else:  # rest
            self._drain_paradigm_buffer()

    def _poll_trained_model(self):
        """Non-blocking check for the MI training result; on success load and warm up the model."""
        try:
            result = self.mp_result_queue.get_nowait()
            if result['status'] == 'success':
                algo_log("MI模型训练完成加载新模型", level="DEBUG")
                # weights_only=False unpickles the whole model object; modelPath is
                # presumably written by our own onlineTrain process.
                self.model = torch.load(self.modelPath, weights_only=False)
                self.model.eval()
                # Warm up with one random epoch shaped exactly like the prediction input.
                n_samples = self.interval_epoch[1] - self.interval_epoch[0]
                warmup_data = np.random.uniform(-1, 1, (self.n_chan, n_samples))
                with torch.no_grad():
                    _ = self.model(self._to_model_input(warmup_data))
                self.load_model = True
                self.zmqServer.broadcast_message('paradigm', 1)  # model loaded, notify the host
            else:
                algo_log(f"MI训练失败: {result.get('msg')}", level="WARNING")
        except Empty:
            pass  # training not finished yet
        except Exception as e:
            algo_log("MI模型训练失败: " + str(e), level="WARNING")

    def decoder_MI(self):
        """One MI step: launch training in a child process once the training set is
        complete, load the model when ready, then collect or classify epochs."""
        # --- start training (once every class has single_train trials) ---
        if not self.train_started and self._train_set_complete():
            # Training set collected; notify the host before training starts.
            self.zmqServer.broadcast_message('paradigm', 2)
            self.train_started = True
            self.trainData = np.array(self.trainData)
            self.trainLabel = np.array(self.trainLabel)
            algo_log(f"MI开始训练训练集{np.shape(self.trainData)}标签shape{np.shape(self.trainLabel)}", level="DEBUG")
            self._save_train_set()
            # Train in a child process so the decoding loop is not blocked.
            p = mp.Process(target=onlineTrain, args=(self.mp_data_queue, self.mp_result_queue))
            p.start()
            self.mp_data_queue.put({'data': self.trainData, 'label': self.trainLabel, 'modelPath': self.modelPath,
                                    'n_chan': self.n_chan})

        # --- load the trained model when it is ready ---
        if not self.load_model and self.train_started:
            self._poll_trained_model()

        mode = self.zmqServer.state_mode
        if mode == 'train' and not self.train_started:  # collect training epochs
            epoch_payload = self._next_epoch(0.0001)
            if epoch_payload is None:
                return
            originalTrial = epoch_payload['snapshot']
            event_inner_idx = epoch_payload['event_inner_idx']
            self.currentLabel = epoch_payload['currentLabel']
            data_length = epoch_payload['data_length']
            # NOTE(review): this length check uses zmqServer.train_epoch while the slice
            # below uses self.interval_epoch; the shape check further down is what
            # guarantees a complete window.
            if originalTrial is None or data_length < self.zmqServer.train_epoch[1] + event_inner_idx:
                algo_log(f"epoch数据长度不足: {data_length}, 跳过", level="WARNING")
                return
            algo_log(f"训练队列数据:{data_length}", level="DEBUG")
            algo_log(f"取出的:{originalTrial.shape},event: {originalTrial[-2, event_inner_idx]}", level="DEBUG")
            lo = event_inner_idx + self.interval_epoch[0]
            hi = event_inner_idx + self.interval_epoch[1]
            trainTrial = self.preprocess(originalTrial[:self.n_chan, :])[:, lo:hi]
            if (trainTrial.shape[1] == self.interval_epoch[1] - self.interval_epoch[0]
                    and isinstance(self.trainLabel, list)
                    and self.trainLabel.count(self.currentLabel) < self.single_train):
                self.trainData.append(trainTrial)
                self.trainLabel.append(self.currentLabel)
                algo_log(f"训练集:{np.shape(self.trainData)}", level="DEBUG")
                self.plotData.append(originalTrial[:self.n_chan, lo:hi])
                self.plotLabel.append(self.currentLabel)
        elif mode == 'predict' and self.load_model:  # online prediction
            if self.zmqServer.StartDecode:
                self.zmqServer.StartDecode = False
                formatted_time = self._now_hms_ms()
                algo_log(f"MI启动预测 {formatted_time}", level="DEBUG")
            epoch_payload = self._next_epoch(0.001)
            if epoch_payload is None:
                return
            originalData = epoch_payload['snapshot']
            event_inner_idx = epoch_payload['event_inner_idx']
            data_length = epoch_payload['data_length']
            if originalData is None or data_length < self.interval_epoch[1] + event_inner_idx:
                algo_log(f"predict epoch数据长度不足: {data_length}, 跳过", level="WARNING")
                return
            algo_log(f"取出的:{originalData.shape},event: {originalData[-2, event_inner_idx]}", level="DEBUG")
            start = time.time()
            lo = event_inner_idx + self.interval_epoch[0]
            hi = event_inner_idx + self.interval_epoch[1]
            data = self.preprocess(originalData[:self.n_chan, :])[:, lo:hi]
            self.plotData.append(originalData[:self.n_chan, lo:hi])
            with torch.no_grad():
                Cls = self.model(self._to_model_input(data))
                y_pred = torch.max(Cls, 1)[1]
            self.plotLabel.append(int(y_pred.item()))
            algo_log(f"MI运动意图识别: {y_pred}")
            self.zmqServer.broadcast_message('result', int(y_pred.item()))
            end = time.time()
            algo_log(f'MI发送给界面完成,耗时{end - start:.3f}s。')
        else:  # rest, training in progress, or model not loaded yet
            self._drain_paradigm_buffer()

    def stop(self):
        '''
        Stop the decoding loop, the ZMQ server and the sliding filter.
        @return: None
        '''
        self.Runing = False  # clear first so run() does not act on half-stopped components
        self.zmqServer.stop()
        self.sliding_filter.stop()

    @staticmethod
    def _drain_queue(q):
        """Discard every item that can currently be taken from *q* without blocking.

        Loops until get_nowait() raises Empty instead of trusting q.empty(), which
        multiprocessing documents as unreliable.
        """
        while True:
            try:
                q.get_nowait()
            except Empty:
                return

    def reset_state(self):
        """Clear decoder state and cached data (called before switching decoders)."""
        self.zmqServer.reset_state()  # device-level buffers
        # decoding state
        self.decodingSteps = 0
        self.calculateCount = 0
        # training data
        self.plotData = []
        self.plotLabel = []
        self.trainData = []
        self.trainLabel = []
        self.currentLabel = -1
        self.train_started = False
        self.load_model = False
        # Empty the multiprocessing queues so data from the previous decoder does not leak into the new one.
        for attr in ('mp_data_queue', 'mp_result_queue'):
            q = getattr(self, attr, None)
            if q is not None:
                self._drain_queue(q)