Files
bci_algo/Decoder.py

500 lines
26 KiB
Python
Raw Normal View History

2026-06-05 09:34:29 +08:00
import ast
2026-06-06 14:51:00 +08:00
import glob
2026-06-06 14:40:07 +08:00
import os
import sys
2026-06-05 09:34:29 +08:00
import threading
from datetime import datetime
import multiprocessing as mp
import numpy as np
import time
import torch
from queue import Empty
from scipy import signal
from torch.autograd import Variable
2026-06-06 14:40:07 +08:00
# from Device.SunnyLinker import SunnyLinker64
2026-06-05 09:34:29 +08:00
from SSMVEP.algorithm.tdca import TDCA
from SSMVEP.algorithm.base import generate_cca_references
from concentration.algorithm.calculate_focus import Calculate
from blinkdetection.algorithm.eye_detection import blink_detection
from Zmq.zmqServer import zmqServer
from Zmq.zmqClient import zmqClient
from MI.Algorithm.conformer_2class import onlineTrain
from PubLibrary.InifileHelper import IniRead
2026-06-06 14:40:07 +08:00
from logs.log import algo_log
2026-06-05 09:34:29 +08:00
from SSVEP.dwfbcca import FbccaDw
2026-06-06 14:40:07 +08:00
# from Tools.plot_MI_EEG import plotMain
2026-06-05 09:34:29 +08:00
from collections import deque
2026-06-07 11:05:24 +08:00
from Zmq.filterProcess import SlidingFilter
2026-06-06 14:40:07 +08:00
def get_root_path():
    """Return the directory the program is running from.

    When ``sys.frozen`` is set (packaged executable) this is the directory
    containing the executable; otherwise it is the directory containing this
    source file.

    NOTE(review): the original comment says this is meant for Nuitka builds;
    confirm that the packaging setup actually sets ``sys.frozen``.
    """
    is_packaged = getattr(sys, 'frozen', False)
    # Packaged: anchor on the executable. Development: anchor on this .py file.
    anchor = sys.executable if is_packaged else os.path.abspath(__file__)
    return os.path.dirname(anchor)
# Name of the folder (joined onto the program root directory) in which the
# trained model files (*.pth) are stored.
MODEL_FOLDER = "online_Models"
class Decoder_main(threading.Thread):
def __init__(self, device_info=None):
    """Create the decoder thread together with the helpers it drives.

    :param device_info: device description; it is handed to the ZMQ server
        unchanged, and this class later reads its 'sample_rate' entry.
    """
    threading.Thread.__init__(self)
    self.device_info = device_info
    self.Runing = True  # main-loop flag, cleared by stop()
    self.decoder = None
    self.decoder_class = None  # name of the active decoder class
    # 0 = decoding stopped, 1 = warm-up, 2 = decoding,
    # 3 = decoding finished, result being sent
    self.decodingSteps = 0

    # The ZMQ server is started immediately (receive thread).
    self.zmqServer = zmqServer(device_info=self.device_info)
    server = self.zmqServer
    server.start()

    # The sliding filter is only constructed here; it is started by run().
    self.sliding_filter = SlidingFilter(
        ring_buffer=server.filterBuffer,
        n_chan=server.device_info['channel_nums'],
        srate=server.device_info['sample_rate'],
    )

    # Filtered data is handed back to the ZMQ server through this callback.
    self.sliding_filter.filter_result_callback = server.send_filtered_data
2026-06-05 09:34:29 +08:00
def is_valid_signal(self, data, threshold=1e5):
    """Tell whether a signal segment is usable, judged by its energy.

    The energy is the mean, over rows, of the per-row variance. The segment
    is rejected when that value exceeds ``threshold``.

    :param data: 2-D array; the original comment gives (chans, samples)
    :param threshold: upper bound accepted for the mean per-row variance
    :return: False when the energy is above ``threshold``, True otherwise
    """
    per_row_variance = np.var(data, axis=1)
    energy = np.mean(per_row_variance)
    return not energy > threshold
def init_Decoder(self, decoder_class):
    '''
    Initialise the decoder for the requested paradigm.

    :param decoder_class: 'ssvep' / 'pvs', 'ssmvep' or 'mi' / 'ma'; for any
        other value only the class name is recorded
    :return: None
    '''
    self.decoder_class = decoder_class

    if decoder_class in ('ssvep', 'pvs'):
        self.n_chan = 8
        # The INI entry holds a pair: (cost method, threshold value).
        DW_cost_method, self.DW_cost_tv = ast.literal_eval(
            IniRead('system', 'SSVEP_ThresholdValue'))
        self.ListFreq = self.zmqServer.targetFreqs
        self.num_target = len(self.ListFreq)
        if self.num_target == 0:
            # No stimulation targets: nothing to build.
            return

        srate = self.device_info['sample_rate']
        # Second-generation algorithm object.
        self.dw = FbccaDw(srate, self.num_target, self.n_chan, 5, 5,
                          0.2, [2.0, 0.1], [8, 7], 50, DW_cost_method)
        self.dw.filterFrequenceBank()  # frequency bands
        self.dw.setNotchFilterPara()
        self.calculateCount = 0
        self.referenceData = self.dw.reference(
            self.ListFreq, int(50 * 0.2 * srate), 5)
        self.dw.filterInit()
        self.dw.onlineInit()  # reset used for online acquisition

    elif decoder_class == 'ssmvep':
        self.zmqServer.interval_init(decoder_class)
        self.n_chan = 8
        # Epoch bounds from the INI file, e.g. [0.2, 2.2].
        self.interval_epoch = ast.literal_eval(
            IniRead('system', 'SSMVEP_IntervalEpoch'))
        # Length of the decoded segment, rounded to 6 decimal places.
        epoch_span = self.interval_epoch[1] - self.interval_epoch[0]
        self.sample_length = round(epoch_span, 6)
        self.single_train = 10  # training trials required per class
        self.num_target = 2  # number of classes
        self.list_freqs = np.array([8, 9])  # stimulation frequencies
        self.list_phase = np.array([0, 0])  # stimulation phases
        self.tdca = TDCA(padding_len=5, n_components=1)
        self.Yf = generate_cca_references(
            self.list_freqs,
            srate=self.device_info['sample_rate'],
            T=self.sample_length,
            phases=self.list_phase,
            n_harmonics=5,
        )
        self.parameter_init(5, 45)

    elif decoder_class in ('mi', 'ma'):
        self.zmqServer.interval_init(decoder_class)
        self.n_chan = 21
        self.interval_epoch = ast.literal_eval(
            IniRead('system', 'MI_IntervalEpoch'))
        # Length of the decoded segment, rounded to 6 decimal places.
        epoch_span = self.interval_epoch[1] - self.interval_epoch[0]
        self.sample_length = round(epoch_span, 6)
        self.single_train = 40  # training trials required per class
        self.num_target = 2  # number of classes
        self.parameter_init(8, 30)
2026-06-06 09:16:49 +08:00
# elif decoder_class == 'concentration':
# self.thread_data_server.interval_inited = False
# self.n_chan = 6
# self.win_len = 10
# self.win_step = 1
# self.low_threshold, self.high_threshold = ast.literal_eval(IniRead('system', 'concentration_ThresholdValue'))
2026-06-06 14:40:07 +08:00
# self.calculate = Calculate(self.low_threshold, self.high_threshold, self.device_info['sample_rate'], self.win_len)
2026-06-06 09:16:49 +08:00
# self.interval_epoch = [0, 1]
# self.parameter_init(2, 40)
# # self.eegQueue moved to Calculate class
2026-06-05 09:34:29 +08:00
2026-06-06 09:16:49 +08:00
# elif decoder_class == 'blink':
# self.n_chan = 2
# self.l_freq = 0.1 # 带通滤波器低频截止
# self.h_freq = 8.0 # 带通滤波器高频截止
# self.total_samples = 0 # 总采样点数
# self.window_ms = 600 # 检测窗口大小 (ms)
# self.step_ms = 100 # 滑动步长 (ms)
2026-06-06 14:40:07 +08:00
# self.window_samples = int(self.window_ms * self.device_info['sample_rate'] / 1000) # 150个样本点
# self.step_samples = int(self.step_ms * self.device_info['sample_rate'] / 1000) # 25个样本点
2026-06-06 09:16:49 +08:00
# self.buffer_size = self.window_samples + self.step_samples * 5
# self.fp1_buffer = deque(maxlen=self.buffer_size)
# self.fp2_buffer = deque(maxlen=self.buffer_size)
# self.sample_counter = 0
# # 预计算滤波器系数,避免在循环中重复设计
# self.Dmin, self.Dmax, self.EMin, self.EMax, self.jitterwin,self.double_blink_interval,self.double_blink_jitter = ast.literal_eval(IniRead('system', 'blink'))
# self.blink_count = 0 # 单次眨眼的次数
# self.last_blink_time = 0 # 上次检测到单次眨眼的时间(样本索引)
# self.blink_timestamps = deque(maxlen=10) # 记录最近10次 单次眨眼的时间戳
# self.double_blink_count = 0 # 连续两次眨眼的次数
# self.double_blink_events = [] # 连续眨眼事件记录
# self.last_double_blink_time = 0 # 上次检测到连续眨眼的时间戳
# self.blink_events = []
2026-06-06 14:40:07 +08:00
# self.blink_b, self.blink_a = signal.butter(4, [self.l_freq / (self.device_info['sample_rate'] / 2), self.h_freq / (self.device_info['sample_rate'] / 2)], btype='band')
2026-06-05 09:34:29 +08:00
def parameter_init(self, bandPass_low, bandPass_high):
    """Prepare epoch windows, training buffers, filters and model storage.

    Expects ``self.interval_epoch`` to hold the epoch bounds in seconds and
    replaces it with the same bounds expressed in samples.

    :param bandPass_low: lower band-pass cut-off in Hz
    :param bandPass_high: upper band-pass cut-off in Hz
    :return: None
    """
    srate = self.device_info['sample_rate']
    nyquist = srate / 2

    # Epoch bounds: seconds -> sample indices.
    self.interval_epoch = [int(i * srate) for i in self.interval_epoch]
    # Training epochs run 0.1 s past the end of the decoding epoch.
    self.train_epoch = [int(self.interval_epoch[0]),
                        int(self.interval_epoch[1] + 0.1 * srate)]

    self.trainData = []  # training trials
    self.trainLabel = []  # training labels
    self.plotData = []  # data kept for the analysis report
    self.plotLabel = []  # labels kept for the analysis report
    self.currentLabel = -1  # label currently shown by the stimulus UI
    self.train_started = False  # whether model training has been launched
    self.load_model = False  # whether the trained model is ready for use

    # 50 Hz mains notch (quality factor 30), frequency normalised to Nyquist.
    self.b_notch, self.a_notch = signal.iirnotch(50 / nyquist, 30)
    # 65-tap FIR band-pass between bandPass_low and bandPass_high.
    self.b_design = signal.firwin(
        65, [bandPass_low / nyquist, bandPass_high / nyquist], pass_zero=False)

    model_dir = os.path.join(get_root_path(), MODEL_FOLDER)
    # Make sure the folder exists so the model can be written later.
    os.makedirs(model_dir, exist_ok=True)
    # Discard models left over from earlier sessions. A file that cannot be
    # removed (e.g. still open elsewhere) must not abort initialisation.
    for old_pth in glob.glob(os.path.join(model_dir, '*.pth')):
        try:
            os.remove(old_pth)
        except OSError as e:
            algo_log(f"Failed to remove old model file {old_pth}: {e}")

    fileName = 'Model_' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    self.modelPath = os.path.join(model_dir, fileName + '.pth')

    # Queues used to exchange data and results with the training process.
    self.mp_data_queue = mp.Queue()
    self.mp_result_queue = mp.Queue()
2026-06-05 09:34:29 +08:00
def preprocess(self, signal_data):
    """Remove the mean, then apply the notch and band-pass filters.

    All three steps operate along the last axis and use the coefficients
    stored in ``self.b_notch`` / ``self.a_notch`` and ``self.b_design``.

    :param signal_data: array whose last axis is filtered
    :return: filtered array with the same shape as the input
    """
    # Subtract the mean taken along the last axis.
    centered = signal_data - np.mean(signal_data, axis=-1, keepdims=True)
    # Mains notch filter.
    notched = signal.lfilter(self.b_notch, self.a_notch, centered, axis=-1)
    # FIR band-pass filter.
    return signal.lfilter(self.b_design, 1, notched, axis=-1)
def run(self):
    """Thread main loop: start the filter, handle switches, run the decoder.

    Each iteration:
      1. starts the sliding filter once more than 5 s of data is buffered;
      2. re-initialises the decoder when the server reports a decoder or
         target change;
      3. turns a 'sync' state into 'rest';
      4. runs one step of the active decoder, or discards buffered data in
         blocks of 25 samples when no decoder is active.
    """
    # The filter is started at most once from here: calling start() a second
    # time on a threading.Thread raises RuntimeError, which would end this loop.
    filter_started = False

    while self.Runing:
        if (not filter_started
                and not self.sliding_filter.is_alive()
                and self.zmqServer.filterBuffer.GetDataLenCount()
                > self.device_info['sample_rate'] * 5):
            algo_log("启动滤波线程", level="DEBUG")
            self.sliding_filter.start()
            filter_started = True

        if self.zmqServer.decoder_switch or self.zmqServer.changeTarget:
            algo_log(f"Decoder_class Switch Detected: {self.zmqServer.decoder_class}", level="DEBUG")
            self.zmqServer.decoder_switch = False
            self.zmqServer.changeTarget = False
            try:
                self.reset_state()  # clear the old state before switching
                self.init_Decoder(self.zmqServer.decoder_class)
            except Exception as e:
                # A failed initialisation must not end the thread: log it and
                # fall back to the idle branch until the next switch request.
                algo_log(f"Decoder Init Error: {e}")
                self.decoder_class = None

        # Synchronisation request
        if self.zmqServer.state_mode == 'sync':
            self.zmqServer.state_mode = 'rest'

        try:
            # NOTE(review): init_Decoder also accepts 'ma', but it is not
            # dispatched here and ends up in the idle branch - confirm whether
            # it should run decoder_MI().
            if self.decoder_class == 'ssvep' or self.decoder_class == 'pvs':
                self.decoder_SSVEP()
            elif self.decoder_class == 'ssmvep':
                self.decoder_SSMVEP()
            elif self.decoder_class == 'mi':
                self.decoder_MI()
            else:
                if self.zmqServer.open_Impedance == False:  # not checking impedance
                    if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
                        time.sleep(0.005)
                        continue
                    self.zmqServer.paradigmBuffer.getData(25)
                else:
                    # Nothing to do during an impedance check; sleep so the
                    # loop does not spin at full CPU.
                    time.sleep(0.005)
        except Exception as e:
            algo_log(f"Decoder Loop Error: {e}")
            time.sleep(0.1)  # prevent CPU spin if the error is persistent
def decoder_SSVEP(self):
    """Run one step of the SSVEP decoder on the next 50-sample block.

    ``self.decodingSteps`` drives the state machine: 1 = warm the filters
    with the block, 2 = decode the block, 3 = a result was found and is being
    sent, 0 = idle. A block is read from the buffer on every call in which at
    least 50 samples are available and no impedance check is running,
    whatever the current step is.
    """
    server = self.zmqServer

    if server.StartDecode:
        server.StartDecode = False
        self.decodingSteps = 1
        server.paradigmBuffer.resetAllPara()
        print('启动预测')

    if server.paradigmBuffer.GetDataLenCount() < 50:
        time.sleep(0.005)
        return
    if server.open_Impedance:
        # No decoding during an impedance check. Sleep so the caller's loop
        # does not spin while the buffer stays above 50 samples.
        time.sleep(0.005)
        return

    data = server.paradigmBuffer.getDataViaSSVEP(50)
    data = data[:self.n_chan, :]

    if not hasattr(self, 'dw'):
        # The algorithm object was never built (e.g. no targets configured).
        return

    if self.decodingSteps == 1:  # warm-up
        self.dw.onlineInit()  # reset used for online acquisition
        self.dw.warmFilter(data)
        self.decodingSteps = 2
        print('预热数据完成。开始预测')
        return

    if self.decodingSteps != 2:
        # Idle, or a stale step left by an earlier failure: nothing to send.
        return

    choosenNum = self.dw.fbccaDWMW(data, self.referenceData, self.DW_cost_tv, self.calculateCount)
    self.calculateCount += 1
    if choosenNum == -1 or not self.is_valid_signal(data):
        return

    self.decodingSteps = 3
    print('预测结果:' + str(choosenNum) + ',计算次数:' + str(self.calculateCount))
    self.calculateCount = 0
    # The result is sent in the same call that produced it, so the value is
    # always bound; the step is reset even if the broadcast fails.
    try:
        server.broadcast_message('result', int(choosenNum))
        print('发送给界面完成。')
    finally:
        self.decodingSteps = 0
def decoder_SSMVEP(self):
    """Run one step of the SSMVEP pipeline.

    First fits the TDCA model as soon as every class has enough training
    trials. Then, depending on the server state: collects one training trial
    ('train'), classifies one epoch and broadcasts the result ('predict'),
    or discards buffered data in blocks of 25 samples (any other state).
    """
    server = self.zmqServer

    # --- model training ---
    enough_trials = self.load_model == False and all(
        self.trainLabel.count(i) >= self.single_train
        for i in range(len(self.list_freqs)))
    if enough_trials:  # model not trained yet and training set complete
        self.trainData = np.array(self.trainData)
        self.trainLabel = np.array(self.trainLabel)
        print(np.shape(self.trainData), (self.trainLabel))
        self.decoder = self.tdca.fit(self.trainData, self.trainLabel, Yf=self.Yf)
        stamp = datetime.now().strftime('%H:%M:%S.%f')[:-3]
        print('模型训练完成', stamp)
        self.load_model = True
        server.broadcast_message('paradigm', 1)

    # --- training state: collect one trial ---
    if server.state_mode == 'train':
        trial_ready = server.epoch_finished and \
            server.paradigmBuffer.GetDataLenCount() >= self.train_epoch[1] + server.event_inner_idx
        if not trial_ready:
            time.sleep(0.0001)
            return

        self.currentLabel = server.currentLabel
        print('训练队列数据:', server.paradigmBuffer.GetDataLenCount())
        trainTrial = server.paradigmBuffer.get_SSMVEPData()  # take all data
        print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, server.event_inner_idx])
        trainTrial = self.preprocess(trainTrial[:self.n_chan, :])
        # Cut the training epoch relative to the event position.
        trainTrial = trainTrial[:, server.event_inner_idx + self.train_epoch[0]:
                                server.event_inner_idx + self.train_epoch[1]]
        print('trial: ', server.event_inner_idx, self.train_epoch[0], self.train_epoch[1])

        full_length = trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0])
        # Labels are only appended while trainLabel is still a list, i.e.
        # before it has been converted to an array for training.
        if full_length and isinstance(self.trainLabel, list) \
                and self.trainLabel.count(self.currentLabel) < self.single_train:
            self.trainData.append(trainTrial)
            self.trainLabel.append(self.currentLabel)
        return

    # --- predict state: classify one epoch ---
    if server.state_mode == 'predict':
        if self.load_model == False:  # model not trained yet
            time.sleep(0.01)
            return

        if server.StartDecode:
            server.StartDecode = False
            stamp = datetime.now().strftime('%H:%M:%S.%f')[:-3]
            print('启动预测 ', stamp)

        if server.epoch_finished == False or \
                server.paradigmBuffer.GetDataLenCount() < self.interval_epoch[1] + server.event_inner_idx:
            time.sleep(0.0001)
            return

        data = server.paradigmBuffer.get_SSMVEPData()  # read all data
        print('取出的: ', data.shape, 'event: ', data[-2, server.event_inner_idx])
        data = self.preprocess(data[:self.n_chan, :])
        data = data[:, server.event_inner_idx + self.interval_epoch[0]:
                    server.event_inner_idx + self.interval_epoch[1]]

        # Zero-pad the epoch by 0.1 s before handing it to the model.
        srate = self.device_info['sample_rate']
        pad_eeg_test = np.zeros((data.shape[0], int((self.sample_length + 0.1) * srate)))
        pad_eeg_test[:, :int(self.sample_length * srate)] = data

        choosenNum, features_2 = self.decoder.predict(pad_eeg_test)
        if isinstance(choosenNum, np.ndarray):
            choosenNum = choosenNum[0]
        ranked = sorted(features_2[0])
        print('结果:', choosenNum, 'rho: ', ranked, ranked[-1] - ranked[-2])
        server.broadcast_message('result', int(choosenNum))
        print('发送给界面完成。')
        return

    # --- rest state: discard buffered data unless checking impedance ---
    if server.open_Impedance == False:
        if server.paradigmBuffer.GetDataLenCount() < 25:
            time.sleep(0.005)
            return
        server.paradigmBuffer.getData(25)
2026-06-05 09:34:29 +08:00
def decoder_MI(self):
    """Run one step of the motor-imagery (MI) pipeline.

    In order:
      1. once every class has enough trials, launch ``onlineTrain`` in a
         child process and hand it the training set;
      2. while training is pending, poll the result queue and load the
         model when training reports success;
      3. depending on the server state, collect one training trial
         ('train'), classify one epoch ('predict' with a loaded model), or
         discard buffered data in blocks of 25 samples (otherwise).
    """
    server = self.zmqServer

    def to_model_input(array):
        """Return ``array`` as a float32 tensor on the loaded model's device."""
        first_param = next(self.model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
        return torch.from_numpy(array).float().to(device)

    # --- launch training ---
    if self.train_started == False and all(
            self.trainLabel.count(i) >= self.single_train for i in range(self.num_target)):
        # Training set complete: notify the host before training starts.
        server.broadcast_message('paradigm', 2)
        self.train_started = True
        self.trainData = np.array(self.trainData)
        self.trainLabel = np.array(self.trainLabel) + 1
        # Train in a child process.
        trainer = mp.Process(target=onlineTrain, args=(self.mp_data_queue, self.mp_result_queue))
        trainer.start()
        self.mp_data_queue.put({'data': self.trainData, 'label': self.trainLabel,
                                'modelPath': self.modelPath, 'n_chan': self.n_chan})

    # --- check whether training has finished and load the model ---
    if self.load_model == False and self.train_started == True:
        try:
            result = self.mp_result_queue.get_nowait()
            if result['status'] == 'success':
                print("模型训练完成,加载新模型")
                # Fall back to the CPU when CUDA is not available; otherwise
                # keep the location stored in the file.
                map_location = None if torch.cuda.is_available() else 'cpu'
                # NOTE: weights_only=False unpickles arbitrary objects; the
                # file at self.modelPath must come from a trusted source.
                self.model = torch.load(self.modelPath, map_location=map_location,
                                        weights_only=False)
                self.model.eval()
                # Warm the model up with random data of the same shape as a
                # real decoding epoch.
                epoch_len = self.interval_epoch[1] - self.interval_epoch[0]
                warmup_data = np.random.uniform(-1, 1, (1, 1, self.n_chan, epoch_len))
                with torch.no_grad():
                    _ = self.model(to_model_input(warmup_data))
                self.load_model = True
                # Model ready: notify the host.
                server.broadcast_message('paradigm', 1)
            else:
                print("训练失败:", result['msg'])
        except Empty:
            pass  # training not finished yet
        except Exception as e:
            print('模型调用失败: ', e)

    # --- training state: collect one trial ---
    if server.state_mode == 'train' and self.train_started == False:
        if server.StartTrain:
            self.currentLabel = server.currentLabel
            server.StartTrain = False
        if server.epoch_finished == False or \
                server.paradigmBuffer.GetDataLenCount() < self.interval_epoch[1] + server.event_inner_idx:
            time.sleep(0.0001)
            return
        print('训练队列数据:', server.paradigmBuffer.GetDataLenCount())
        originalTrial = server.paradigmBuffer.get_MIData()  # MI channel data
        print('取出的: ', originalTrial.shape, 'event: ', originalTrial[-2, server.event_inner_idx])
        trainTrial = self.preprocess(originalTrial[:self.n_chan, :])
        # Cut the epoch relative to the event position.
        trainTrial = trainTrial[:, server.event_inner_idx + self.interval_epoch[0]:
                                server.event_inner_idx + self.interval_epoch[1]]
        print('trial: ', server.event_inner_idx, self.interval_epoch[0], self.interval_epoch[1])
        if trainTrial.shape[1] == (self.interval_epoch[1] - self.interval_epoch[0]) \
                and isinstance(self.trainLabel, list) \
                and self.trainLabel.count(self.currentLabel) < self.single_train:
            self.trainData.append(trainTrial)
            self.trainLabel.append(self.currentLabel)
            print('训练集:', np.shape(self.trainData))
            # Keep the unfiltered epoch for the analysis report.
            self.plotData.append(
                originalTrial[:self.n_chan, server.event_inner_idx + self.interval_epoch[0]:
                              server.event_inner_idx + self.interval_epoch[1]])
            self.plotLabel.append(self.currentLabel)

    # --- predict state: classify one epoch ---
    elif server.state_mode == 'predict' and self.load_model == True:
        if server.StartDecode:
            server.StartDecode = False
            stamp = datetime.now().strftime('%H:%M:%S.%f')[:-3]
            print('启动预测 ', stamp)
        if server.epoch_finished == False or \
                server.paradigmBuffer.GetDataLenCount() < self.interval_epoch[1] + server.event_inner_idx:
            time.sleep(0.0001)
            return
        originalData = server.paradigmBuffer.get_MIData()  # read all data
        print('取出的: ', originalData.shape, 'event: ', originalData[-2, server.event_inner_idx])
        start = time.time()
        data = self.preprocess(originalData[:self.n_chan, :])
        data = data[:, server.event_inner_idx + self.interval_epoch[0]:
                    server.event_inner_idx + self.interval_epoch[1]]
        # Keep the unfiltered epoch for the analysis report.
        self.plotData.append(
            originalData[:self.n_chan, server.event_inner_idx + self.interval_epoch[0]:
                         server.event_inner_idx + self.interval_epoch[1]])
        # Model input layout: (1, 1, channels, samples).
        test_data = to_model_input(data[np.newaxis, np.newaxis, :, :])
        with torch.no_grad():
            Cls = self.model(test_data)
        y_pred = torch.max(Cls, 1)[1]
        self.plotLabel.append(int(y_pred.item()))
        print('运动意图识别: ', y_pred)
        # NOTE(review): the prediction is sent under 'paradigm', the same
        # topic used above for the status codes 1 and 2, while the other
        # decoders send predictions under 'result' - confirm this is intended.
        server.broadcast_message('paradigm', int(y_pred.item()))
        end = time.time()
        print(f'发送给界面完成,耗时{end - start:.3f}s。')

    # --- rest state: discard buffered data unless checking impedance ---
    else:
        if server.open_Impedance == False:
            if server.paradigmBuffer.GetDataLenCount() < 25:
                time.sleep(0.005)
                return
            server.paradigmBuffer.getData(25)
2026-06-05 09:34:29 +08:00
2026-06-06 14:40:07 +08:00
# def decoder_concentration(self):
# if self.zmqServer.state_mode == 'predict':
# if self.zmqServer.StartDecode:
# self.zmqServer.StartDecode = False
# self.thread_data_server.ResetAll()
# now = datetime.now()
# formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
# print('启动专注力预测 ', formatted_time)
# if self.thread_data_server.GetDataLenCount() < int(self.win_step * self.device_info['sample_rate']): # 每win_step得出一次结果
# time.sleep(0.005)
# return
# if self.zmqServer.get_Impedance != False: # 阻抗检测状态不解码
# return
# data = self.thread_data_server.get_concentrateData(int(self.win_step * self.device_info['sample_rate'])) # 修改每次读取的数据
# result = self.calculate.queueOpt(data)
# if result is not None:
# self.zmqClient.send_to_all('result', int(result))
# else: # 休息状态
# if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
# if self.thread_data_server.GetDataLenCount() < 25:
# time.sleep(0.005)
# return
# self.thread_data_server.getData(25)
2026-06-05 09:34:29 +08:00
def stop(self):
    '''
    Stop the decoder loop, the ZMQ server and the sliding filter.

    The loop flag is cleared first: run() starts the sliding filter whenever
    it is not alive, so stopping the filter while the loop is still enabled
    could make run() try to start it again.
    @return: None
    '''
    self.Runing = False
    self.zmqServer.stop()
    self.sliding_filter.stop()
def reset_state(self):
    """Clear the decoder state and all cached data."""
    # Reset the caches held by the ZMQ server.
    self.zmqServer.reset_state()

    # Decoding state
    self.decodingSteps = 0
    self.calculateCount = 0

    # Training / report data
    self.plotData, self.plotLabel = [], []
    self.trainData, self.trainLabel = [], []
    self.currentLabel = -1
    self.train_started = False
    self.load_model = False

    # Drain the multiprocessing queues (when they exist) so that items queued
    # for the previous decoder are not picked up after a switch.
    for queue_name in ('mp_data_queue', 'mp_result_queue'):
        if not hasattr(self, queue_name):
            continue
        while not getattr(self, queue_name).empty():
            try:
                getattr(self, queue_name).get_nowait()
            except Empty:
                pass