Files
bci_algo/Decoder.py

662 lines
34 KiB
Python
Raw Normal View History

2026-06-05 09:34:29 +08:00
import ast
2026-06-06 14:51:00 +08:00
import glob
2026-06-06 14:40:07 +08:00
import os
import sys
2026-06-05 09:34:29 +08:00
import threading
from datetime import datetime
import multiprocessing as mp
import numpy as np
import time
import torch
from queue import Empty
from scipy import signal
from torch.autograd import Variable
2026-06-06 14:40:07 +08:00
# from Device.SunnyLinker import SunnyLinker64
2026-06-05 09:34:29 +08:00
from SSMVEP.algorithm.tdca import TDCA
from SSMVEP.algorithm.base import generate_cca_references
from concentration.algorithm.calculate_focus import Calculate
from blinkdetection.algorithm.eye_detection import blink_detection
from Zmq.zmqServer import zmqServer
from Zmq.zmqClient import zmqClient
from MI.Algorithm.conformer_2class import onlineTrain
from PubLibrary.InifileHelper import IniRead
2026-06-06 14:40:07 +08:00
from logs.log import algo_log
2026-06-05 09:34:29 +08:00
from SSVEP.dwfbcca import FbccaDw
2026-06-06 14:40:07 +08:00
# from Tools.plot_MI_EEG import plotMain
2026-06-05 09:34:29 +08:00
from collections import deque
2026-06-06 14:40:07 +08:00
def get_root_path():
    """Return the directory of the running program (.py file or packaged .exe).

    * Freezers that set ``sys.frozen`` (e.g. PyInstaller): directory of ``sys.executable``.
    * Nuitka does not set ``sys.frozen``; it defines a module-level ``__compiled__``
      instead. In onefile mode ``__file__`` points into the temporary unpack
      directory, so the directory of ``sys.argv[0]`` (the launched .exe) is used.
    * Running from source: directory of this .py file.
    """
    if getattr(sys, 'frozen', False):
        return os.path.dirname(sys.executable)
    if '__compiled__' in globals():
        return os.path.dirname(os.path.abspath(sys.argv[0]))
    return os.path.dirname(os.path.abspath(__file__))


# Sub-folder of get_root_path() where trained online models (*.pth) are stored.
MODEL_FOLDER = "online_Models"
class Decoder_main(threading.Thread):
    """Worker thread that runs the online BCI decoders (SSVEP/PVS, SSMVEP, MI/MA)."""

    def __init__(self, device_info=None):
        """Create the decoder thread and start its ZMQ server.

        :param device_info: mapping with at least the keys ``sample_rate``,
            ``frame_points``, ``channel_nums``, ``channel_names`` and
            ``channel_index``. It defaults to None only for signature
            compatibility; it is required.
        :raises TypeError: if ``device_info`` is None.
        :raises KeyError: if a required key is missing.
        """
        if device_info is None:
            raise TypeError("Decoder_main requires device_info (sample_rate, frame_points, "
                            "channel_nums, channel_names, channel_index)")
        super().__init__()
        required_keys = ('sample_rate', 'frame_points', 'channel_nums', 'channel_names', 'channel_index')
        self.device_info = {key: device_info[key] for key in required_keys}

        self.Runing = True  # loop flag read by run(); spelling kept since other code may use it
        self.decoder = None
        self.decoder_class = None  # active decoder type, e.g. 'ssvep', 'ssmvep', 'mi'
        # 0 = not decoding, 1 = warm-up, 2 = decoding, 3 = decoding finished, sending result
        self.decodingSteps = 0

        # NOTE(review): no zmqClient is created here; any code still calling
        # self.zmqClient will raise AttributeError.
        self.zmqServer = zmqServer(device_info=self.device_info)
        self.zmqServer.start()
2026-06-06 14:40:07 +08:00
# def connect(self, device_type=None, device_host=None, device_port=None, upper_host=None, upper_port=None):
# self.DeviceType = device_type if device_type is not None else int(IniRead('system', 'Device_type'))
# _device_host = device_host if device_host is not None else str(IniRead('system', 'Device_Host'))
# _device_port = device_port if device_port is not None else int(IniRead('system', 'Device_Port'))
# _upper_host = upper_host if upper_host is not None else str(IniRead('system', 'Upper_Host'))
# _upper_port = upper_port if upper_port is not None else int(IniRead('system', 'Upper_Port'))
# if self.DeviceType == 1:
# self.thread_data_server = SunnyLinker64(_device_host, _device_port, self.device_info['sample_rate'], 64, method='tcp')
# self.thread_data_server.host = _device_host
# self.thread_data_server.port = _device_port
# self.thread_data_server.toUv = True
# self.thread_data_server.start()
2026-06-05 09:34:29 +08:00
def is_valid_signal(self, data, threshold=1e5): # 判断当前信号是否为有效信号
# data: (chans, samples)
energy = np.mean(np.var(data, axis=1)) # 各通道方差均值
if energy > threshold:
return False
return True
def init_Decoder(self,decoder_class):
'''
初始化解码器
:param decoder_class: 'ssvep' or 'ssmvep' or 'mi' or 'concentration' or ''
:return:
'''
self.decoder_class = decoder_class
if decoder_class == 'ssvep' or decoder_class == 'pvs':
self.n_chan = 8
2026-06-06 14:40:07 +08:00
# self.thread_data_server.interval_inited = False
2026-06-05 09:34:29 +08:00
DW_cost_method, self.DW_cost_tv = ast.literal_eval(IniRead('system', 'SSVEP_ThresholdValue'))
self.ListFreq = self.zmqServer.targetFreqs
self.num_target = len(self.ListFreq)
if self.num_target == 0:
return
# 初始化对象 二代算法
2026-06-06 14:40:07 +08:00
self.dw = FbccaDw(self.device_info['sample_rate'], self.num_target, self.n_chan, 5, 5,
2026-06-05 09:34:29 +08:00
0.2, [2.0, 0.1], [8, 7], 50, DW_cost_method)
# frequence band
self.dw.filterFrequenceBank()
self.dw.setNotchFilterPara()
self.calculateCount = 0
2026-06-06 14:40:07 +08:00
self.referenceData = self.dw.reference(self.ListFreq, int(50 * 0.2 * self.device_info['sample_rate']), 5)
2026-06-05 09:34:29 +08:00
self.dw.filterInit()
self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时
elif decoder_class == 'ssmvep':
2026-06-06 14:40:07 +08:00
self.zmqServer.interval_init(decoder_class)
2026-06-05 09:34:29 +08:00
self.n_chan = 8
self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位
self.single_train = 10 # 单类别数量
self.num_target = 2 # 分类目标数目
self.list_freqs = np.array([8, 9]) # 刺激频率
self.list_phase = np.array([0, 0]) # 相位
self.tdca = TDCA(padding_len=5, n_components=1)
2026-06-06 14:40:07 +08:00
self.Yf = generate_cca_references(self.list_freqs, srate=self.device_info['sample_rate'], T=self.sample_length,
2026-06-05 09:34:29 +08:00
phases=self.list_phase, n_harmonics=5)
self.parameter_init(5,45)
elif decoder_class == 'mi' or decoder_class == 'ma':
2026-06-06 14:40:07 +08:00
self.zmqServer.interval_init(decoder_class)
2026-06-05 09:34:29 +08:00
self.n_chan = 21
self.interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位
self.single_train = 40 # 单类别数量
self.num_target = 2 # 分类目标数目
self.parameter_init(8, 30)
2026-06-06 09:16:49 +08:00
# elif decoder_class == 'concentration':
# self.thread_data_server.interval_inited = False
# self.n_chan = 6
# self.win_len = 10
# self.win_step = 1
# self.low_threshold, self.high_threshold = ast.literal_eval(IniRead('system', 'concentration_ThresholdValue'))
2026-06-06 14:40:07 +08:00
# self.calculate = Calculate(self.low_threshold, self.high_threshold, self.device_info['sample_rate'], self.win_len)
2026-06-06 09:16:49 +08:00
# self.interval_epoch = [0, 1]
# self.parameter_init(2, 40)
# # self.eegQueue moved to Calculate class
2026-06-05 09:34:29 +08:00
2026-06-06 09:16:49 +08:00
# elif decoder_class == 'blink':
# self.n_chan = 2
# self.l_freq = 0.1 # 带通滤波器低频截止
# self.h_freq = 8.0 # 带通滤波器高频截止
# self.total_samples = 0 # 总采样点数
# self.window_ms = 600 # 检测窗口大小 (ms)
# self.step_ms = 100 # 滑动步长 (ms)
2026-06-06 14:40:07 +08:00
# self.window_samples = int(self.window_ms * self.device_info['sample_rate'] / 1000) # 150个样本点
# self.step_samples = int(self.step_ms * self.device_info['sample_rate'] / 1000) # 25个样本点
2026-06-06 09:16:49 +08:00
# self.buffer_size = self.window_samples + self.step_samples * 5
# self.fp1_buffer = deque(maxlen=self.buffer_size)
# self.fp2_buffer = deque(maxlen=self.buffer_size)
# self.sample_counter = 0
# # 预计算滤波器系数,避免在循环中重复设计
# self.Dmin, self.Dmax, self.EMin, self.EMax, self.jitterwin,self.double_blink_interval,self.double_blink_jitter = ast.literal_eval(IniRead('system', 'blink'))
# self.blink_count = 0 # 单次眨眼的次数
# self.last_blink_time = 0 # 上次检测到单次眨眼的时间(样本索引)
# self.blink_timestamps = deque(maxlen=10) # 记录最近10次 单次眨眼的时间戳
# self.double_blink_count = 0 # 连续两次眨眼的次数
# self.double_blink_events = [] # 连续眨眼事件记录
# self.last_double_blink_time = 0 # 上次检测到连续眨眼的时间戳
# self.blink_events = []
2026-06-06 14:40:07 +08:00
# self.blink_b, self.blink_a = signal.butter(4, [self.l_freq / (self.device_info['sample_rate'] / 2), self.h_freq / (self.device_info['sample_rate'] / 2)], btype='band')
2026-06-05 09:34:29 +08:00
def parameter_init(self,bandPass_low,bandPass_high):
2026-06-06 14:40:07 +08:00
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in self.interval_epoch] # epoch截取信息
self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])] # 训练样本epoch
2026-06-05 09:34:29 +08:00
self.trainData = [] #训练数据
self.trainLabel = [] #训练标签
self.plotData = [] #报告分析数据
self.plotLabel = [] #报告分析标签
self.currentLabel = -1 #刺激界面当前显示的训练标签
self.train_started = False #是否开始训练模型
self.load_model = False # 调用模型是否完成的标志
2026-06-06 14:40:07 +08:00
self.b_notch, self.a_notch = signal.iirnotch(50 / (self.device_info['sample_rate']/2), 30) # 50Hz工频陷波250是采样率30是质量因子
self.b_design = signal.firwin(65, [bandPass_low / (self.device_info['sample_rate']/2), bandPass_high / (self.device_info['sample_rate']/2)], pass_zero=False) # 设计8-30Hz带通滤波器
2026-06-05 09:34:29 +08:00
fileName = 'Model_' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
2026-06-06 15:13:23 +08:00
filePath = os.path.join(get_root_path(), MODEL_FOLDER) + os.sep
2026-06-06 14:51:00 +08:00
for old_pth in glob.glob(os.path.join(filePath, '*.pth')):
os.remove(old_pth)
2026-06-05 09:34:29 +08:00
self.modelPath = ''.join([filePath, fileName, '.pth'])
2026-06-06 14:51:00 +08:00
self.mp_data_queue = mp.Queue()
self.mp_result_queue = mp.Queue()
2026-06-05 09:34:29 +08:00
def preprocess(self, signal_data):
# # 计算每行的平均值
row_means = np.mean(signal_data, axis=-1, keepdims=True)
# 对每一行去均值
signal_data = signal_data - row_means
signal_data = signal.lfilter(self.b_notch, self.a_notch, signal_data, axis=-1) # 工频陷波
signal_data = signal.lfilter(self.b_design, 1, signal_data, axis=-1) # 带通滤波
return signal_data
def run(self):
while self.Runing:
if self.zmqServer.decoder_switch or self.zmqServer.changeTarget:
print(f"Decoder_class Switch Detected: {self.zmqServer.decoder_class}")
self.zmqServer.decoder_switch = False
self.zmqServer.changeTarget = False
self.reset_state() # 切换前先统一清理旧状态
self.init_Decoder(self.zmqServer.decoder_class)
# 同步信息
if self.zmqServer.state_mode == 'sync':
2026-06-06 14:40:07 +08:00
# self.zmqClient.send_to_all('sync', self.zmqClient.state)
2026-06-05 09:34:29 +08:00
self.zmqServer.state_mode = 'rest'
2026-06-06 14:40:07 +08:00
# # 状态异常,报告上位机
# if self.status_code != self.thread_data_server.status_code:
# self.status_code = self.thread_data_server.status_code
# self.zmqClient.send_to_all('status_code', int(self.status_code))
# print('status code')
# # 返回电量
# if self.energy != self.thread_data_server.energy:
# self.energy = self.thread_data_server.energy
# self.zmqClient.send_to_all('energy', int(self.energy))
# print('energy')
# if self.zmqServer.open_Impedance == True: # 开启阻抗检测功能,仅运行一次
# self.thread_data_server.Impedance(True)
# print('Impedance')
# self.zmqServer.open_Impedance = -1
# elif self.zmqServer.open_Impedance == False:
# self.thread_data_server.Impedance(False)
# self.zmqServer.open_Impedance = -1
# if self.zmqServer.get_Impedance: # 返回阻抗值
# # print(self.zmqServer.get_Impedance)
# # print(self.thread_data_server.GetDataLenCount())
# if self.thread_data_server.GetDataLenCount() > 250:
# Impe_data = self.thread_data_server.getData(250)
# # 计算阻抗
# imps = self.thread_data_server.getImpedance(Impe_data,self.zmqServer.decoder_class)
# self.zmqClient.send_to_all('impedance', imps.tolist())
# else:
# pass
# if self.zmqServer.getReport: #返回训练报告内容
# self.zmqServer.getReport = False
# allData = np.array(self.plotData)
# allLabel = np.array(self.plotLabel) + 1
# nTrials = min(len(allLabel),len(allData))
# if nTrials < 30:
# self.zmqClient.send_to_all('miReport',0)
# else:
# allData = allData[:nTrials]
# allLabel = allLabel[:nTrials]
# ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1',
# 'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4']
# compare_names = ['C3', 'CZ', 'C4']
# miReport = plotMain(ch_names=ch_names,compare_names=compare_names,Data=allData,labels=allLabel,MI_label=1,Rest_label=2,
# fs=self.device_info['sample_rate'])
# self.zmqClient.send_to_all('miReport',miReport)
2026-06-05 09:34:29 +08:00
# --- 取数优先:先执行 decoder消费环形缓冲再处理 plot/report 等重负载 ---
try:
if self.decoder_class == 'ssvep' or self.decoder_class == 'pvs':
self.decoder_SSVEP()
elif self.decoder_class == 'ssmvep':
self.decoder_SSMVEP()
elif self.decoder_class == 'mi':
self.decoder_MI()
2026-06-06 14:40:07 +08:00
# elif self.decoder_class == 'concentration':
# self.decoder_concentration()
# elif self.decoder_class == 'blink':
# self.decoder_blink()
# else:
# self.
# # if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
# # if self.thread_data_server.GetDataLenCount() < 25:
# # time.sleep(0.005)
# # continue;
# # self.thread_data_server.getData(25)
2026-06-05 09:34:29 +08:00
except Exception as e:
2026-06-06 14:40:07 +08:00
algo_log(f"Decoder Loop Error: {e}")
2026-06-05 09:34:29 +08:00
time.sleep(0.1) # Prevent CPU spin if error is persistent
def decoder_SSVEP(self):
if self.zmqServer.StartDecode:
self.zmqServer.StartDecode = False
self.decodingSteps = 1
2026-06-06 14:40:07 +08:00
self.zmqServer.paradigmBuffer.ResetAllPara()
2026-06-05 09:34:29 +08:00
print('启动预测')
2026-06-06 14:40:07 +08:00
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 50:
2026-06-05 09:34:29 +08:00
time.sleep(0.005)
return
if self.zmqServer.get_Impedance != False: # 阻抗检测状态不解码
return
2026-06-06 14:40:07 +08:00
data = self.zmqServer.paradigmBuffer.getDataViaSSVEP(50)
2026-06-05 09:34:29 +08:00
data = data[:self.n_chan, :]
if self.decodingSteps == 1 and hasattr(self,'dw'): # 开始预热
self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时
self.dw.warmFilter(data) # 预热
self.decodingSteps = 2
print('预热数据完成。开始预测')
return
if self.decodingSteps == 2 and hasattr(self,'dw'): # 解码中
choosenNum = self.dw.fbccaDWMW(data, self.referenceData, self.DW_cost_tv, self.calculateCount)
self.calculateCount += 1
if choosenNum != -1 and self.is_valid_signal(data):
self.decodingSteps = 3
print('预测结果:' + str(choosenNum) + ',计算次数:' + str(self.calculateCount))
self.calculateCount = 0
if self.decodingSteps == 3: # 发送解码后的信息
2026-06-06 14:40:07 +08:00
self.zmqServer.broadcast_message('result', int(choosenNum))
2026-06-05 09:34:29 +08:00
self.decodingSteps = 0
print('发送给界面完成。')
def decoder_SSMVEP(self):
'''模型训练'''
if self.load_model == False and all(
self.trainLabel.count(i) >= self.single_train for i in range(len(self.list_freqs))): # 模型尚未训练完成
self.trainData = np.array(self.trainData)
self.trainLabel = np.array(self.trainLabel)
print(np.shape(self.trainData), (self.trainLabel))
# 保存多个数组到文件
# np.savez('20250520_yy.npz', array1=self.trainData, array2=self.trainLabel)
# self.decoder = self.fbtdca.fit(self.trainData, self.trainLabel, Yf=self.Yf)
self.decoder = self.tdca.fit(self.trainData, self.trainLabel, Yf=self.Yf)
now = datetime.now()
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
print('模型训练完成', formatted_time)
self.load_model = True
2026-06-06 14:40:07 +08:00
self.zmqServer.broadcast_message('paradigm', 1)
2026-06-05 09:34:29 +08:00
'''训练阶段采集数据'''
if self.zmqServer.state_mode == 'train': # 训练状态
if self.zmqServer.StartTrain:
self.currentLabel = self.zmqServer.currentLabel
self.zmqServer.StartTrain = False
2026-06-06 14:40:07 +08:00
if self.thread_data_server.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
2026-06-05 09:34:29 +08:00
self.train_epoch[1] \
+ self.thread_data_server.event_inner_idx:
time.sleep(0.0001)
return
2026-06-06 14:40:07 +08:00
print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount())
2026-06-05 09:34:29 +08:00
trainTrial = self.thread_data_server.get_SSMVEPData() # 取出所有数据
print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, self.thread_data_server.event_inner_idx])
trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理
trainTrial = trainTrial[:, self.thread_data_server.event_inner_idx + self.train_epoch[
0]:self.thread_data_server.event_inner_idx + self.train_epoch[1]]
print('trial: ', self.thread_data_server.event_inner_idx, self.train_epoch[0], self.train_epoch[1])
if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance(
self.trainLabel, list) \
and self.trainLabel.count(self.currentLabel) < self.single_train:
self.trainData.append(trainTrial)
self.trainLabel.append(self.currentLabel)
elif self.zmqServer.state_mode == 'predict': # 测试状态
if self.load_model == False: # 模型尚未训练完成
time.sleep(0.01)
return
else: # 已有模型
if self.zmqServer.StartDecode:
self.zmqServer.StartDecode = False
now = datetime.now()
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
print('启动预测 ', formatted_time)
2026-06-06 14:40:07 +08:00
if self.thread_data_server.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
2026-06-05 09:34:29 +08:00
self.interval_epoch[1] \
+ self.thread_data_server.event_inner_idx:
time.sleep(0.0001)
return
data = self.thread_data_server.get_SSMVEPData() # 读取全部数据
print('取出的: ', data.shape, 'event: ', data[-2, self.thread_data_server.event_inner_idx])
data = self.preprocess(data[:self.n_chan, :]) # 预处理
data = data[:,
self.thread_data_server.event_inner_idx + self.interval_epoch[
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]]
pad_eeg_test = np.zeros(
2026-06-06 14:40:07 +08:00
(data.shape[0], int((self.sample_length + 0.1) * self.device_info['sample_rate'])))
pad_eeg_test[:, :int(self.sample_length * self.device_info['sample_rate'])] = data
2026-06-05 09:34:29 +08:00
choosenNum, features_2 = self.decoder.predict(pad_eeg_test)
if isinstance(choosenNum, np.ndarray):
choosenNum = choosenNum[0]
print('结果:', choosenNum, 'rho: ', sorted(features_2[0]),
sorted(features_2[0])[-1] - sorted(features_2[0])[-2])
self.zmqClient.send_to_all('result', int(choosenNum))
print('发送给界面完成。')
else: # 休息状态
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
2026-06-06 14:40:07 +08:00
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
2026-06-05 09:34:29 +08:00
time.sleep(0.005)
return
self.thread_data_server.getData(25)
def decoder_MI(self):
'''模型训练'''
if self.train_started == False and all(
self.trainLabel.count(i) >= self.single_train for i in range(self.num_target)): # 模型尚未训练
self.zmqClient.send_to_all('paradigm', 2) # 模型训练前,训练集采集完毕,通知上位机
self.train_started = True
self.trainData = np.array(self.trainData)
self.trainLabel = np.array(self.trainLabel) + 1
# print('训练集:',np.shape(self.trainData), (self.trainLabel))
p = mp.Process(target=onlineTrain, args=(self.mp_data_queue, self.mp_result_queue)) # 开启子进程,训练模型
p.start()
self.mp_data_queue.put({'data': self.trainData, 'label': self.trainLabel, 'modelPath': self.modelPath,
'n_chan': self.n_chan})
'''检查模型是否训练完成,调用'''
if self.load_model == False and self.train_started == True:
try:
result = self.mp_result_queue.get_nowait()
if result['status'] == 'success':
print("模型训练完成,加载新模型")
# 调用模型
self.model = torch.load(self.modelPath, weights_only=False)
self.model.eval()
# 模型预热
warmup_data = np.random.uniform(-1, 1, (1, 1, self.n_chan, 1000))
warmup_data = torch.from_numpy(warmup_data)
warmup_data = Variable(warmup_data.type(torch.cuda.FloatTensor))
with torch.no_grad():
_ = self.model(warmup_data)
self.load_model = True
self.zmqClient.send_to_all('paradigm', 1) # 模型调用完毕,通知上位机
else:
print("训练失败:", result['msg'])
except Empty:
pass # 还没完成
except Exception as e:
print('模型调用失败: ', e)
'''训练阶段采集数据'''
if self.zmqServer.state_mode == 'train' and self.train_started == False: # 训练状态
if self.zmqServer.StartTrain:
self.currentLabel = self.zmqServer.currentLabel
self.zmqServer.StartTrain = False
2026-06-06 14:40:07 +08:00
if self.thread_data_server.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
2026-06-05 09:34:29 +08:00
self.interval_epoch[1] \
+ self.thread_data_server.event_inner_idx:
time.sleep(0.0001)
return
2026-06-06 14:40:07 +08:00
print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount())
2026-06-05 09:34:29 +08:00
originalTrial = self.thread_data_server.get_MIData() # 取出MI导联数据
print('取出的: ', originalTrial.shape, 'event: ', originalTrial[-2, self.thread_data_server.event_inner_idx])
trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理
trainTrial = trainTrial[:, self.thread_data_server.event_inner_idx + self.interval_epoch[
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]]
print('trial: ', self.thread_data_server.event_inner_idx, self.interval_epoch[0], self.interval_epoch[1])
if trainTrial.shape[1] == (self.interval_epoch[1] - self.interval_epoch[0]) and isinstance(self.trainLabel,
list) \
and self.trainLabel.count(self.currentLabel) < self.single_train:
self.trainData.append(trainTrial)
self.trainLabel.append(self.currentLabel)
print('训练集:', np.shape(self.trainData))
self.plotData.append(originalTrial[:self.n_chan, self.thread_data_server.event_inner_idx + self.interval_epoch[
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]])
self.plotLabel.append(self.currentLabel)
elif self.zmqServer.state_mode == 'predict' and self.load_model == True: # 测试状态
if self.zmqServer.StartDecode:
self.zmqServer.StartDecode = False
now = datetime.now()
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
print('启动预测 ', formatted_time)
2026-06-06 14:40:07 +08:00
if self.thread_data_server.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
2026-06-05 09:34:29 +08:00
self.interval_epoch[1] \
+ self.thread_data_server.event_inner_idx:
time.sleep(0.0001)
return
originalData = self.thread_data_server.get_MIData() # 读取全部数据
print('取出的: ', originalData.shape, 'event: ', originalData[-2, self.thread_data_server.event_inner_idx])
start = time.time()
data = self.preprocess(originalData[:self.n_chan, :]) # 预处理
data = data[:,
self.thread_data_server.event_inner_idx + self.interval_epoch[
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]]
self.plotData.append(
originalData[:self.n_chan, self.thread_data_server.event_inner_idx + self.interval_epoch[
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]])
test_data = data[np.newaxis, np.newaxis, :, :]
test_data = torch.from_numpy(test_data)
test_data = Variable(test_data.type(torch.cuda.FloatTensor))
with torch.no_grad():
Cls = self.model(test_data)
y_pred = torch.max(Cls, 1)[1]
self.plotLabel.append(int(y_pred.item()))
print('运动意图识别: ', y_pred)
self.zmqClient.send_to_all('result', int(y_pred.item()))
end = time.time()
print(f'发送给界面完成,耗时{end - start:.3f}s。')
else: # 休息状态
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
2026-06-06 14:40:07 +08:00
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
2026-06-05 09:34:29 +08:00
time.sleep(0.005)
return
self.thread_data_server.getData(25)
2026-06-06 14:40:07 +08:00
# def decoder_concentration(self):
# if self.zmqServer.state_mode == 'predict':
# if self.zmqServer.StartDecode:
# self.zmqServer.StartDecode = False
# self.thread_data_server.ResetAll()
# now = datetime.now()
# formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
# print('启动专注力预测 ', formatted_time)
# if self.thread_data_server.GetDataLenCount() < int(self.win_step * self.device_info['sample_rate']): # 每win_step得出一次结果
# time.sleep(0.005)
# return
# if self.zmqServer.get_Impedance != False: # 阻抗检测状态不解码
# return
# data = self.thread_data_server.get_concentrateData(int(self.win_step * self.device_info['sample_rate'])) # 修改每次读取的数据
# result = self.calculate.queueOpt(data)
# if result is not None:
# self.zmqClient.send_to_all('result', int(result))
# else: # 休息状态
# if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
# if self.thread_data_server.GetDataLenCount() < 25:
# time.sleep(0.005)
# return
# self.thread_data_server.getData(25)
2026-06-05 09:34:29 +08:00
#### Blink detection #####
2026-06-06 14:40:07 +08:00
# def check_double_blink(self, current_time):
# """
# 检查是否检测到连续两次眨眼
# @param current_time: 当前眨眼时间戳
# @return: True表示检测到连续两次眨眼
# """
# if len(self.blink_timestamps) < 2:
# return False
# # 检查是否在去抖期内
# if self.last_double_blink_time > 0:
# time_since_last_double_blink = current_time - self.last_double_blink_time
# if time_since_last_double_blink < self.double_blink_jitter:
# return False # 在去抖期内,忽略连续眨眼检测
# last_time = self.blink_timestamps[-1] # 当前眨眼
# prev_time = self.blink_timestamps[-2] # 上次眨眼
# interval = last_time - prev_time
# if interval <= self.double_blink_interval:
# return True
# return False
# def process_blink_detection(self):
# """
# 在缓冲区数据上执行,单次眨眼检测
# """
# if len(self.fp1_buffer) < self.window_samples:
# return
# fp1_data = np.array(list(self.fp1_buffer)[-self.window_samples:])
# fp2_data = np.array(list(self.fp2_buffer)[-self.window_samples:])
# # 计算FP1和FP2的平均
# fp12_mean = (fp1_data + fp2_data) / 2.0
# # 带通滤波
# try:
# fp12_filtered = signal.filtfilt(self.blink_b, self.blink_a, fp12_mean)
# except Exception as e:
# print(f"Filter error: {e}")
# return
# F = np.diff(fp12_filtered)
# if len(F) < 3:
# return
# b, d, e = blink_detection(F, self.device_info['sample_rate'], self.Dmin, self.Dmax, self.EMin, self.EMax)
# if b == 1:
# samples_since_last = self.total_samples - self.last_blink_time
# time_since_last_ms = (samples_since_last / self.device_info['sample_rate']) * 1000
# if time_since_last_ms >= self.jitterwin: # self.jitterwin 单次眨眼去抖 using time_since_last_ms
# self.blink_count += 1
# self.last_blink_time = self.total_samples
# current_time = time.time()
# self.blink_timestamps.append(current_time)
# blink_event = {
# 'count': self.blink_count,
# 'time': current_time,
# 'sample_index': self.total_samples,
# 'duration_ms': d,
# 'energy': e
# }
# self.blink_events.append(blink_event)
# self.zmqClient.send_to_all('result', 1) # 检测到眨眼信号,通知上位机
# if self.check_double_blink(current_time):
# self.double_blink_count += 1
# interval = self.blink_timestamps[-1] - self.blink_timestamps[-2]
# double_blink_event = {
# 'double_blink_count': self.double_blink_count,
# 'blink1_time': self.blink_timestamps[-2],
# 'blink2_time': self.blink_timestamps[-1],
# 'interval': interval
# }
# self.double_blink_events.append(double_blink_event)
# self.last_double_blink_time = current_time
# self.zmqClient.send_to_all('result', 2) # 发送双次眨眼事件
# def decoder_blink(self):
# if self.thread_data_server.GetDataLenCount() < 50:
# time.sleep(0.005)
# return
# if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
# data = self.thread_data_server.get_blinkData(50)
# fp1_data = data[0, :] # ch1 (相当于FP1)
# fp2_data = data[1, :] # ch2 (相当于FP2)
# for i in range(len(fp1_data)):
# self.fp1_buffer.append(fp1_data[i])
# self.fp2_buffer.append(fp2_data[i])
# self.total_samples += 1
# self.sample_counter += 1
# if self.sample_counter >= self.step_samples:
# self.process_blink_detection()
# self.sample_counter = 0
2026-06-05 09:34:29 +08:00
def stop(self):
'''
停止运行
@return:
'''
self.zmqServer.stop()
self.Runing=False
def reset_state(self):
"""清空解码器状态和缓存数据"""
# 重置设备层缓存
self.thread_data_server.reset_state()
# 重置解码状态
self.decodingSteps = 0
self.calculateCount = 0
# 重置训练数据
self.plotData = []
self.plotLabel = []
self.trainData = []
self.trainLabel = []
self.currentLabel = -1
self.train_started = False
self.load_model = False
# 重置多进程队列,确保切换 decoder 时旧数据不会泄漏到新队列
if hasattr(self, 'mp_data_queue'):
while not self.mp_data_queue.empty():
try: self.mp_data_queue.get_nowait()
except Empty: pass
if hasattr(self, 'mp_result_queue'):
while not self.mp_result_queue.empty():
try: self.mp_result_queue.get_nowait()
except Empty: pass