add buffer
This commit is contained in:
476
Decoder.py
476
Decoder.py
@@ -1,4 +1,6 @@
|
|||||||
import ast
|
import ast
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
import threading
|
import threading
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
@@ -8,7 +10,7 @@ import torch
|
|||||||
from queue import Empty
|
from queue import Empty
|
||||||
from scipy import signal
|
from scipy import signal
|
||||||
from torch.autograd import Variable
|
from torch.autograd import Variable
|
||||||
from Device.SunnyLinker import SunnyLinker64
|
# from Device.SunnyLinker import SunnyLinker64
|
||||||
from SSMVEP.algorithm.tdca import TDCA
|
from SSMVEP.algorithm.tdca import TDCA
|
||||||
from SSMVEP.algorithm.base import generate_cca_references
|
from SSMVEP.algorithm.base import generate_cca_references
|
||||||
from concentration.algorithm.calculate_focus import Calculate
|
from concentration.algorithm.calculate_focus import Calculate
|
||||||
@@ -17,49 +19,70 @@ from Zmq.zmqServer import zmqServer
|
|||||||
from Zmq.zmqClient import zmqClient
|
from Zmq.zmqClient import zmqClient
|
||||||
from MI.Algorithm.conformer_2class import onlineTrain
|
from MI.Algorithm.conformer_2class import onlineTrain
|
||||||
from PubLibrary.InifileHelper import IniRead
|
from PubLibrary.InifileHelper import IniRead
|
||||||
|
from logs.log import algo_log
|
||||||
from SSVEP.dwfbcca import FbccaDw
|
from SSVEP.dwfbcca import FbccaDw
|
||||||
from Tools.plot_MI_EEG import plotMain
|
# from Tools.plot_MI_EEG import plotMain
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
class Decoder_main(threading.Thread, device_type):
|
|
||||||
def __init__(self, device_type=None):
|
def get_root_path():
|
||||||
|
"""
|
||||||
|
Nuitka 打包专用:获取程序根目录(.py 或 .exe 所在目录)
|
||||||
|
"""
|
||||||
|
if getattr(sys, 'frozen', False):
|
||||||
|
# 打包后:返回 exe 所在目录
|
||||||
|
return os.path.dirname(sys.executable)
|
||||||
|
else:
|
||||||
|
# 开发时:返回 py 文件所在目录
|
||||||
|
return os.path.dirname(os.path.abspath(__file__))
|
||||||
|
MODEL_FOLDER = "online_Models"
|
||||||
|
|
||||||
|
|
||||||
|
class Decoder_main(threading.Thread):
|
||||||
|
def __init__(self, device_info=None):
|
||||||
threading.Thread.__init__(self)
|
threading.Thread.__init__(self)
|
||||||
|
self.device_info = {
|
||||||
|
'sample_rate': device_info['sample_rate'],
|
||||||
|
'frame_points': device_info['frame_points'],
|
||||||
|
'channel_nums': device_info['channel_nums'],
|
||||||
|
'channel_names': device_info['channel_names'],
|
||||||
|
'channel_index': device_info['channel_index'],
|
||||||
|
}
|
||||||
self.Runing=True
|
self.Runing=True
|
||||||
self.decoder = None
|
self.decoder = None
|
||||||
|
|
||||||
self.fs = 250 # 采样率
|
|
||||||
self.energy = 0 # 电量
|
|
||||||
self.status_code = 0 # 与采集设备通信的状态码,0为异常,1为正常
|
|
||||||
self.decoder_class = None #解码器类别
|
self.decoder_class = None #解码器类别
|
||||||
|
|
||||||
|
# 与采集设备通信的状态码,0为异常,1为正常
|
||||||
|
# self.status_code = 0
|
||||||
|
# self.device_info['sample_rate'] = 250 # 采样率
|
||||||
|
# self.energy = 0 # 电量
|
||||||
|
|
||||||
|
|
||||||
self.decodingSteps = 0 # 0=停止解码 1=预热 2=解码中 3=解码完成,发送解码结果
|
self.decodingSteps = 0 # 0=停止解码 1=预热 2=解码中 3=解码完成,发送解码结果
|
||||||
self.device_info = {
|
|
||||||
'device_type': None,
|
|
||||||
'sample_rate': None,
|
|
||||||
'channel_num': None,
|
|
||||||
}
|
|
||||||
|
|
||||||
def connect(self, device_type=None, device_host=None, device_port=None, upper_host=None, upper_port=None):
|
self.zmqServer = zmqServer(device_info=self.device_info)
|
||||||
self.DeviceType = device_type if device_type is not None else int(IniRead('system', 'Device_type'))
|
|
||||||
_device_host = device_host if device_host is not None else str(IniRead('system', 'Device_Host'))
|
|
||||||
_device_port = device_port if device_port is not None else int(IniRead('system', 'Device_Port'))
|
|
||||||
_upper_host = upper_host if upper_host is not None else str(IniRead('system', 'Upper_Host'))
|
|
||||||
_upper_port = upper_port if upper_port is not None else int(IniRead('system', 'Upper_Port'))
|
|
||||||
|
|
||||||
if self.DeviceType == 1:
|
|
||||||
self.thread_data_server = SunnyLinker64(_device_host, _device_port, self.fs, 64, method='tcp')
|
|
||||||
self.thread_data_server.host = _device_host
|
|
||||||
self.thread_data_server.port = _device_port
|
|
||||||
|
|
||||||
self.thread_data_server.toUv = True
|
|
||||||
self.thread_data_server.start()
|
|
||||||
|
|
||||||
self.zmqServer = zmqServer()
|
|
||||||
self.zmqServer.start()
|
self.zmqServer.start()
|
||||||
|
|
||||||
self.zmqClient = zmqClient(_upper_host, _upper_port)
|
# self.zmqClient = zmqClient(_upper_host, _upper_port)
|
||||||
self.zmqClient.set_zmq_server(self.zmqServer)
|
# self.zmqClient.set_zmq_server(self.zmqServer)
|
||||||
self.zmqClient.connect()
|
# self.zmqClient.connect()
|
||||||
|
|
||||||
|
|
||||||
|
# def connect(self, device_type=None, device_host=None, device_port=None, upper_host=None, upper_port=None):
|
||||||
|
# self.DeviceType = device_type if device_type is not None else int(IniRead('system', 'Device_type'))
|
||||||
|
# _device_host = device_host if device_host is not None else str(IniRead('system', 'Device_Host'))
|
||||||
|
# _device_port = device_port if device_port is not None else int(IniRead('system', 'Device_Port'))
|
||||||
|
# _upper_host = upper_host if upper_host is not None else str(IniRead('system', 'Upper_Host'))
|
||||||
|
# _upper_port = upper_port if upper_port is not None else int(IniRead('system', 'Upper_Port'))
|
||||||
|
|
||||||
|
# if self.DeviceType == 1:
|
||||||
|
# self.thread_data_server = SunnyLinker64(_device_host, _device_port, self.device_info['sample_rate'], 64, method='tcp')
|
||||||
|
# self.thread_data_server.host = _device_host
|
||||||
|
# self.thread_data_server.port = _device_port
|
||||||
|
|
||||||
|
# self.thread_data_server.toUv = True
|
||||||
|
# self.thread_data_server.start()
|
||||||
|
|
||||||
|
|
||||||
def is_valid_signal(self, data, threshold=1e5): # 判断当前信号是否为有效信号
|
def is_valid_signal(self, data, threshold=1e5): # 判断当前信号是否为有效信号
|
||||||
# data: (chans, samples)
|
# data: (chans, samples)
|
||||||
@@ -76,26 +99,25 @@ class Decoder_main(threading.Thread, device_type):
|
|||||||
self.decoder_class = decoder_class
|
self.decoder_class = decoder_class
|
||||||
if decoder_class == 'ssvep' or decoder_class == 'pvs':
|
if decoder_class == 'ssvep' or decoder_class == 'pvs':
|
||||||
self.n_chan = 8
|
self.n_chan = 8
|
||||||
self.thread_data_server.interval_inited = False
|
# self.thread_data_server.interval_inited = False
|
||||||
DW_cost_method, self.DW_cost_tv = ast.literal_eval(IniRead('system', 'SSVEP_ThresholdValue'))
|
DW_cost_method, self.DW_cost_tv = ast.literal_eval(IniRead('system', 'SSVEP_ThresholdValue'))
|
||||||
self.ListFreq = self.zmqServer.targetFreqs
|
self.ListFreq = self.zmqServer.targetFreqs
|
||||||
self.num_target = len(self.ListFreq)
|
self.num_target = len(self.ListFreq)
|
||||||
if self.num_target == 0:
|
if self.num_target == 0:
|
||||||
return
|
return
|
||||||
# 初始化对象 二代算法
|
# 初始化对象 二代算法
|
||||||
self.dw = FbccaDw(self.fs, self.num_target, self.n_chan, 5, 5,
|
self.dw = FbccaDw(self.device_info['sample_rate'], self.num_target, self.n_chan, 5, 5,
|
||||||
0.2, [2.0, 0.1], [8, 7], 50, DW_cost_method)
|
0.2, [2.0, 0.1], [8, 7], 50, DW_cost_method)
|
||||||
# frequence band
|
# frequence band
|
||||||
self.dw.filterFrequenceBank()
|
self.dw.filterFrequenceBank()
|
||||||
self.dw.setNotchFilterPara()
|
self.dw.setNotchFilterPara()
|
||||||
self.calculateCount = 0
|
self.calculateCount = 0
|
||||||
self.referenceData = self.dw.reference(self.ListFreq, int(50 * 0.2 * self.fs),
|
self.referenceData = self.dw.reference(self.ListFreq, int(50 * 0.2 * self.device_info['sample_rate']), 5)
|
||||||
5)
|
|
||||||
self.dw.filterInit()
|
self.dw.filterInit()
|
||||||
self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时
|
self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时
|
||||||
|
|
||||||
elif decoder_class == 'ssmvep':
|
elif decoder_class == 'ssmvep':
|
||||||
self.thread_data_server.interval_init(decoder_class)
|
self.zmqServer.interval_init(decoder_class)
|
||||||
self.n_chan = 8
|
self.n_chan = 8
|
||||||
self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
|
self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
|
||||||
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位
|
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位
|
||||||
@@ -104,12 +126,12 @@ class Decoder_main(threading.Thread, device_type):
|
|||||||
self.list_freqs = np.array([8, 9]) # 刺激频率
|
self.list_freqs = np.array([8, 9]) # 刺激频率
|
||||||
self.list_phase = np.array([0, 0]) # 相位
|
self.list_phase = np.array([0, 0]) # 相位
|
||||||
self.tdca = TDCA(padding_len=5, n_components=1)
|
self.tdca = TDCA(padding_len=5, n_components=1)
|
||||||
self.Yf = generate_cca_references(self.list_freqs, srate=self.fs, T=self.sample_length,
|
self.Yf = generate_cca_references(self.list_freqs, srate=self.device_info['sample_rate'], T=self.sample_length,
|
||||||
phases=self.list_phase, n_harmonics=5)
|
phases=self.list_phase, n_harmonics=5)
|
||||||
self.parameter_init(5,45)
|
self.parameter_init(5,45)
|
||||||
|
|
||||||
elif decoder_class == 'mi' or decoder_class == 'ma':
|
elif decoder_class == 'mi' or decoder_class == 'ma':
|
||||||
self.thread_data_server.interval_init(decoder_class)
|
self.zmqServer.interval_init(decoder_class)
|
||||||
self.n_chan = 21
|
self.n_chan = 21
|
||||||
self.interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
|
self.interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
|
||||||
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位
|
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位
|
||||||
@@ -124,7 +146,7 @@ class Decoder_main(threading.Thread, device_type):
|
|||||||
# self.win_len = 10
|
# self.win_len = 10
|
||||||
# self.win_step = 1
|
# self.win_step = 1
|
||||||
# self.low_threshold, self.high_threshold = ast.literal_eval(IniRead('system', 'concentration_ThresholdValue'))
|
# self.low_threshold, self.high_threshold = ast.literal_eval(IniRead('system', 'concentration_ThresholdValue'))
|
||||||
# self.calculate = Calculate(self.low_threshold, self.high_threshold, self.fs, self.win_len)
|
# self.calculate = Calculate(self.low_threshold, self.high_threshold, self.device_info['sample_rate'], self.win_len)
|
||||||
# self.interval_epoch = [0, 1]
|
# self.interval_epoch = [0, 1]
|
||||||
# self.parameter_init(2, 40)
|
# self.parameter_init(2, 40)
|
||||||
# # self.eegQueue moved to Calculate class
|
# # self.eegQueue moved to Calculate class
|
||||||
@@ -136,8 +158,8 @@ class Decoder_main(threading.Thread, device_type):
|
|||||||
# self.total_samples = 0 # 总采样点数
|
# self.total_samples = 0 # 总采样点数
|
||||||
# self.window_ms = 600 # 检测窗口大小 (ms)
|
# self.window_ms = 600 # 检测窗口大小 (ms)
|
||||||
# self.step_ms = 100 # 滑动步长 (ms)
|
# self.step_ms = 100 # 滑动步长 (ms)
|
||||||
# self.window_samples = int(self.window_ms * self.fs / 1000) # 150个样本点
|
# self.window_samples = int(self.window_ms * self.device_info['sample_rate'] / 1000) # 150个样本点
|
||||||
# self.step_samples = int(self.step_ms * self.fs / 1000) # 25个样本点
|
# self.step_samples = int(self.step_ms * self.device_info['sample_rate'] / 1000) # 25个样本点
|
||||||
# self.buffer_size = self.window_samples + self.step_samples * 5
|
# self.buffer_size = self.window_samples + self.step_samples * 5
|
||||||
# self.fp1_buffer = deque(maxlen=self.buffer_size)
|
# self.fp1_buffer = deque(maxlen=self.buffer_size)
|
||||||
# self.fp2_buffer = deque(maxlen=self.buffer_size)
|
# self.fp2_buffer = deque(maxlen=self.buffer_size)
|
||||||
@@ -151,11 +173,11 @@ class Decoder_main(threading.Thread, device_type):
|
|||||||
# self.double_blink_events = [] # 连续眨眼事件记录
|
# self.double_blink_events = [] # 连续眨眼事件记录
|
||||||
# self.last_double_blink_time = 0 # 上次检测到连续眨眼的时间戳
|
# self.last_double_blink_time = 0 # 上次检测到连续眨眼的时间戳
|
||||||
# self.blink_events = []
|
# self.blink_events = []
|
||||||
# self.blink_b, self.blink_a = signal.butter(4, [self.l_freq / (self.fs / 2), self.h_freq / (self.fs / 2)], btype='band')
|
# self.blink_b, self.blink_a = signal.butter(4, [self.l_freq / (self.device_info['sample_rate'] / 2), self.h_freq / (self.device_info['sample_rate'] / 2)], btype='band')
|
||||||
|
|
||||||
def parameter_init(self,bandPass_low,bandPass_high):
|
def parameter_init(self,bandPass_low,bandPass_high):
|
||||||
self.interval_epoch = [int(i * self.fs) for i in self.interval_epoch] # epoch截取信息
|
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in self.interval_epoch] # epoch截取信息
|
||||||
self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * self.fs)] # 训练样本epoch
|
self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])] # 训练样本epoch
|
||||||
self.trainData = [] #训练数据
|
self.trainData = [] #训练数据
|
||||||
self.trainLabel = [] #训练标签
|
self.trainLabel = [] #训练标签
|
||||||
self.plotData = [] #报告分析数据
|
self.plotData = [] #报告分析数据
|
||||||
@@ -163,10 +185,10 @@ class Decoder_main(threading.Thread, device_type):
|
|||||||
self.currentLabel = -1 #刺激界面当前显示的训练标签
|
self.currentLabel = -1 #刺激界面当前显示的训练标签
|
||||||
self.train_started = False #是否开始训练模型
|
self.train_started = False #是否开始训练模型
|
||||||
self.load_model = False # 调用模型是否完成的标志
|
self.load_model = False # 调用模型是否完成的标志
|
||||||
self.b_notch, self.a_notch = signal.iirnotch(50 / (self.fs/2), 30) # 50Hz工频陷波,250是采样率,30是质量因子
|
self.b_notch, self.a_notch = signal.iirnotch(50 / (self.device_info['sample_rate']/2), 30) # 50Hz工频陷波,250是采样率,30是质量因子
|
||||||
self.b_design = signal.firwin(65, [bandPass_low / (self.fs/2), bandPass_high / (self.fs/2)], pass_zero=False) # 设计8-30Hz带通滤波器
|
self.b_design = signal.firwin(65, [bandPass_low / (self.device_info['sample_rate']/2), bandPass_high / (self.device_info['sample_rate']/2)], pass_zero=False) # 设计8-30Hz带通滤波器
|
||||||
fileName = 'Model_' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
|
fileName = 'Model_' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
|
||||||
filePath = './online_Models/'
|
filePath = os.path.join(get_root_path(), MODEL_FOLDER) + os.sep
|
||||||
self.modelPath = ''.join([filePath, fileName, '.pth'])
|
self.modelPath = ''.join([filePath, fileName, '.pth'])
|
||||||
self.mp_data_queue = mp.Queue() #多进程传参队列
|
self.mp_data_queue = mp.Queue() #多进程传参队列
|
||||||
self.mp_result_queue = mp.Queue() #多进程结果队列
|
self.mp_result_queue = mp.Queue() #多进程结果队列
|
||||||
@@ -192,54 +214,55 @@ class Decoder_main(threading.Thread, device_type):
|
|||||||
|
|
||||||
# 同步信息
|
# 同步信息
|
||||||
if self.zmqServer.state_mode == 'sync':
|
if self.zmqServer.state_mode == 'sync':
|
||||||
self.zmqClient.send_to_all('sync', self.zmqClient.state)
|
# self.zmqClient.send_to_all('sync', self.zmqClient.state)
|
||||||
self.zmqServer.state_mode = 'rest'
|
self.zmqServer.state_mode = 'rest'
|
||||||
# 状态异常,报告上位机
|
|
||||||
if self.status_code != self.thread_data_server.status_code:
|
|
||||||
self.status_code = self.thread_data_server.status_code
|
|
||||||
self.zmqClient.send_to_all('status_code', int(self.status_code))
|
|
||||||
print('status code')
|
|
||||||
|
|
||||||
# 返回电量
|
# # 状态异常,报告上位机
|
||||||
if self.energy != self.thread_data_server.energy:
|
# if self.status_code != self.thread_data_server.status_code:
|
||||||
self.energy = self.thread_data_server.energy
|
# self.status_code = self.thread_data_server.status_code
|
||||||
self.zmqClient.send_to_all('energy', int(self.energy))
|
# self.zmqClient.send_to_all('status_code', int(self.status_code))
|
||||||
print('energy')
|
# print('status code')
|
||||||
|
|
||||||
if self.zmqServer.open_Impedance == True: # 开启阻抗检测功能,仅运行一次
|
# # 返回电量
|
||||||
self.thread_data_server.Impedance(True)
|
# if self.energy != self.thread_data_server.energy:
|
||||||
print('Impedance')
|
# self.energy = self.thread_data_server.energy
|
||||||
self.zmqServer.open_Impedance = -1
|
# self.zmqClient.send_to_all('energy', int(self.energy))
|
||||||
elif self.zmqServer.open_Impedance == False:
|
# print('energy')
|
||||||
self.thread_data_server.Impedance(False)
|
|
||||||
self.zmqServer.open_Impedance = -1
|
|
||||||
|
|
||||||
if self.zmqServer.get_Impedance: # 返回阻抗值
|
# if self.zmqServer.open_Impedance == True: # 开启阻抗检测功能,仅运行一次
|
||||||
# print(self.zmqServer.get_Impedance)
|
# self.thread_data_server.Impedance(True)
|
||||||
# print(self.thread_data_server.GetDataLenCount())
|
# print('Impedance')
|
||||||
if self.thread_data_server.GetDataLenCount() > 250:
|
# self.zmqServer.open_Impedance = -1
|
||||||
Impe_data = self.thread_data_server.getData(250)
|
# elif self.zmqServer.open_Impedance == False:
|
||||||
# 计算阻抗
|
# self.thread_data_server.Impedance(False)
|
||||||
imps = self.thread_data_server.getImpedance(Impe_data,self.zmqServer.decoder_class)
|
# self.zmqServer.open_Impedance = -1
|
||||||
self.zmqClient.send_to_all('impedance', imps.tolist())
|
|
||||||
else:
|
# if self.zmqServer.get_Impedance: # 返回阻抗值
|
||||||
pass
|
# # print(self.zmqServer.get_Impedance)
|
||||||
if self.zmqServer.getReport: #返回训练报告内容
|
# # print(self.thread_data_server.GetDataLenCount())
|
||||||
self.zmqServer.getReport = False
|
# if self.thread_data_server.GetDataLenCount() > 250:
|
||||||
allData = np.array(self.plotData)
|
# Impe_data = self.thread_data_server.getData(250)
|
||||||
allLabel = np.array(self.plotLabel) + 1
|
# # 计算阻抗
|
||||||
nTrials = min(len(allLabel),len(allData))
|
# imps = self.thread_data_server.getImpedance(Impe_data,self.zmqServer.decoder_class)
|
||||||
if nTrials < 30:
|
# self.zmqClient.send_to_all('impedance', imps.tolist())
|
||||||
self.zmqClient.send_to_all('miReport',0)
|
# else:
|
||||||
else:
|
# pass
|
||||||
allData = allData[:nTrials]
|
# if self.zmqServer.getReport: #返回训练报告内容
|
||||||
allLabel = allLabel[:nTrials]
|
# self.zmqServer.getReport = False
|
||||||
ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1',
|
# allData = np.array(self.plotData)
|
||||||
'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4']
|
# allLabel = np.array(self.plotLabel) + 1
|
||||||
compare_names = ['C3', 'CZ', 'C4']
|
# nTrials = min(len(allLabel),len(allData))
|
||||||
miReport = plotMain(ch_names=ch_names,compare_names=compare_names,Data=allData,labels=allLabel,MI_label=1,Rest_label=2,
|
# if nTrials < 30:
|
||||||
fs=self.fs)
|
# self.zmqClient.send_to_all('miReport',0)
|
||||||
self.zmqClient.send_to_all('miReport',miReport)
|
# else:
|
||||||
|
# allData = allData[:nTrials]
|
||||||
|
# allLabel = allLabel[:nTrials]
|
||||||
|
# ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1',
|
||||||
|
# 'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4']
|
||||||
|
# compare_names = ['C3', 'CZ', 'C4']
|
||||||
|
# miReport = plotMain(ch_names=ch_names,compare_names=compare_names,Data=allData,labels=allLabel,MI_label=1,Rest_label=2,
|
||||||
|
# fs=self.device_info['sample_rate'])
|
||||||
|
# self.zmqClient.send_to_all('miReport',miReport)
|
||||||
|
|
||||||
|
|
||||||
# --- 取数优先:先执行 decoder(消费环形缓冲),再处理 plot/report 等重负载 ---
|
# --- 取数优先:先执行 decoder(消费环形缓冲),再处理 plot/report 等重负载 ---
|
||||||
@@ -250,34 +273,33 @@ class Decoder_main(threading.Thread, device_type):
|
|||||||
self.decoder_SSMVEP()
|
self.decoder_SSMVEP()
|
||||||
elif self.decoder_class == 'mi':
|
elif self.decoder_class == 'mi':
|
||||||
self.decoder_MI()
|
self.decoder_MI()
|
||||||
elif self.decoder_class == 'concentration':
|
# elif self.decoder_class == 'concentration':
|
||||||
self.decoder_concentration()
|
# self.decoder_concentration()
|
||||||
elif self.decoder_class == 'blink':
|
# elif self.decoder_class == 'blink':
|
||||||
self.decoder_blink()
|
# self.decoder_blink()
|
||||||
else:
|
# else:
|
||||||
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
# self.
|
||||||
if self.thread_data_server.GetDataLenCount() < 25:
|
# # if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||||
time.sleep(0.005)
|
# # if self.thread_data_server.GetDataLenCount() < 25:
|
||||||
continue;
|
# # time.sleep(0.005)
|
||||||
self.thread_data_server.getData(25)
|
# # continue;
|
||||||
|
# # self.thread_data_server.getData(25)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Decoder Loop Error: {e}")
|
algo_log(f"Decoder Loop Error: {e}")
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
time.sleep(0.1) # Prevent CPU spin if error is persistent
|
time.sleep(0.1) # Prevent CPU spin if error is persistent
|
||||||
|
|
||||||
def decoder_SSVEP(self):
|
def decoder_SSVEP(self):
|
||||||
if self.zmqServer.StartDecode:
|
if self.zmqServer.StartDecode:
|
||||||
self.zmqServer.StartDecode = False
|
self.zmqServer.StartDecode = False
|
||||||
self.decodingSteps = 1
|
self.decodingSteps = 1
|
||||||
self.thread_data_server.ResetAll()
|
self.zmqServer.paradigmBuffer.ResetAllPara()
|
||||||
print('启动预测')
|
print('启动预测')
|
||||||
if self.thread_data_server.GetDataLenCount() < 50:
|
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 50:
|
||||||
time.sleep(0.005)
|
time.sleep(0.005)
|
||||||
return
|
return
|
||||||
if self.zmqServer.get_Impedance != False: # 阻抗检测状态不解码
|
if self.zmqServer.get_Impedance != False: # 阻抗检测状态不解码
|
||||||
return
|
return
|
||||||
data = self.thread_data_server.getDataViaSSVEP(50)
|
data = self.zmqServer.paradigmBuffer.getDataViaSSVEP(50)
|
||||||
data = data[:self.n_chan, :]
|
data = data[:self.n_chan, :]
|
||||||
if self.decodingSteps == 1 and hasattr(self,'dw'): # 开始预热
|
if self.decodingSteps == 1 and hasattr(self,'dw'): # 开始预热
|
||||||
self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时
|
self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时
|
||||||
@@ -293,7 +315,7 @@ class Decoder_main(threading.Thread, device_type):
|
|||||||
print('预测结果:' + str(choosenNum) + ',计算次数:' + str(self.calculateCount))
|
print('预测结果:' + str(choosenNum) + ',计算次数:' + str(self.calculateCount))
|
||||||
self.calculateCount = 0
|
self.calculateCount = 0
|
||||||
if self.decodingSteps == 3: # 发送解码后的信息
|
if self.decodingSteps == 3: # 发送解码后的信息
|
||||||
self.zmqClient.send_to_all('result', int(choosenNum))
|
self.zmqServer.broadcast_message('result', int(choosenNum))
|
||||||
self.decodingSteps = 0
|
self.decodingSteps = 0
|
||||||
print('发送给界面完成。')
|
print('发送给界面完成。')
|
||||||
|
|
||||||
@@ -312,19 +334,19 @@ class Decoder_main(threading.Thread, device_type):
|
|||||||
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
||||||
print('模型训练完成', formatted_time)
|
print('模型训练完成', formatted_time)
|
||||||
self.load_model = True
|
self.load_model = True
|
||||||
self.zmqClient.send_to_all('paradigm', 1)
|
self.zmqServer.broadcast_message('paradigm', 1)
|
||||||
|
|
||||||
'''训练阶段采集数据'''
|
'''训练阶段采集数据'''
|
||||||
if self.zmqServer.state_mode == 'train': # 训练状态
|
if self.zmqServer.state_mode == 'train': # 训练状态
|
||||||
if self.zmqServer.StartTrain:
|
if self.zmqServer.StartTrain:
|
||||||
self.currentLabel = self.zmqServer.currentLabel
|
self.currentLabel = self.zmqServer.currentLabel
|
||||||
self.zmqServer.StartTrain = False
|
self.zmqServer.StartTrain = False
|
||||||
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \
|
if self.thread_data_server.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
||||||
self.train_epoch[1] \
|
self.train_epoch[1] \
|
||||||
+ self.thread_data_server.event_inner_idx:
|
+ self.thread_data_server.event_inner_idx:
|
||||||
time.sleep(0.0001)
|
time.sleep(0.0001)
|
||||||
return
|
return
|
||||||
print('训练队列数据:', self.thread_data_server.GetDataLenCount())
|
print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount())
|
||||||
trainTrial = self.thread_data_server.get_SSMVEPData() # 取出所有数据
|
trainTrial = self.thread_data_server.get_SSMVEPData() # 取出所有数据
|
||||||
print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, self.thread_data_server.event_inner_idx])
|
print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, self.thread_data_server.event_inner_idx])
|
||||||
trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理
|
trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理
|
||||||
@@ -348,7 +370,7 @@ class Decoder_main(threading.Thread, device_type):
|
|||||||
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
||||||
print('启动预测 ', formatted_time)
|
print('启动预测 ', formatted_time)
|
||||||
|
|
||||||
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \
|
if self.thread_data_server.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
||||||
self.interval_epoch[1] \
|
self.interval_epoch[1] \
|
||||||
+ self.thread_data_server.event_inner_idx:
|
+ self.thread_data_server.event_inner_idx:
|
||||||
time.sleep(0.0001)
|
time.sleep(0.0001)
|
||||||
@@ -360,8 +382,8 @@ class Decoder_main(threading.Thread, device_type):
|
|||||||
self.thread_data_server.event_inner_idx + self.interval_epoch[
|
self.thread_data_server.event_inner_idx + self.interval_epoch[
|
||||||
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]]
|
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]]
|
||||||
pad_eeg_test = np.zeros(
|
pad_eeg_test = np.zeros(
|
||||||
(data.shape[0], int((self.sample_length + 0.1) * self.fs)))
|
(data.shape[0], int((self.sample_length + 0.1) * self.device_info['sample_rate'])))
|
||||||
pad_eeg_test[:, :int(self.sample_length * self.fs)] = data
|
pad_eeg_test[:, :int(self.sample_length * self.device_info['sample_rate'])] = data
|
||||||
choosenNum, features_2 = self.decoder.predict(pad_eeg_test)
|
choosenNum, features_2 = self.decoder.predict(pad_eeg_test)
|
||||||
if isinstance(choosenNum, np.ndarray):
|
if isinstance(choosenNum, np.ndarray):
|
||||||
choosenNum = choosenNum[0]
|
choosenNum = choosenNum[0]
|
||||||
@@ -371,7 +393,7 @@ class Decoder_main(threading.Thread, device_type):
|
|||||||
print('发送给界面完成。')
|
print('发送给界面完成。')
|
||||||
else: # 休息状态
|
else: # 休息状态
|
||||||
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||||
if self.thread_data_server.GetDataLenCount() < 25:
|
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
|
||||||
time.sleep(0.005)
|
time.sleep(0.005)
|
||||||
return
|
return
|
||||||
self.thread_data_server.getData(25)
|
self.thread_data_server.getData(25)
|
||||||
@@ -419,12 +441,12 @@ class Decoder_main(threading.Thread, device_type):
|
|||||||
if self.zmqServer.StartTrain:
|
if self.zmqServer.StartTrain:
|
||||||
self.currentLabel = self.zmqServer.currentLabel
|
self.currentLabel = self.zmqServer.currentLabel
|
||||||
self.zmqServer.StartTrain = False
|
self.zmqServer.StartTrain = False
|
||||||
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \
|
if self.thread_data_server.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
||||||
self.interval_epoch[1] \
|
self.interval_epoch[1] \
|
||||||
+ self.thread_data_server.event_inner_idx:
|
+ self.thread_data_server.event_inner_idx:
|
||||||
time.sleep(0.0001)
|
time.sleep(0.0001)
|
||||||
return
|
return
|
||||||
print('训练队列数据:', self.thread_data_server.GetDataLenCount())
|
print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount())
|
||||||
originalTrial = self.thread_data_server.get_MIData() # 取出MI导联数据
|
originalTrial = self.thread_data_server.get_MIData() # 取出MI导联数据
|
||||||
print('取出的: ', originalTrial.shape, 'event: ', originalTrial[-2, self.thread_data_server.event_inner_idx])
|
print('取出的: ', originalTrial.shape, 'event: ', originalTrial[-2, self.thread_data_server.event_inner_idx])
|
||||||
trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理
|
trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理
|
||||||
@@ -448,7 +470,7 @@ class Decoder_main(threading.Thread, device_type):
|
|||||||
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
||||||
print('启动预测 ', formatted_time)
|
print('启动预测 ', formatted_time)
|
||||||
|
|
||||||
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \
|
if self.thread_data_server.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
||||||
self.interval_epoch[1] \
|
self.interval_epoch[1] \
|
||||||
+ self.thread_data_server.event_inner_idx:
|
+ self.thread_data_server.event_inner_idx:
|
||||||
time.sleep(0.0001)
|
time.sleep(0.0001)
|
||||||
@@ -477,128 +499,128 @@ class Decoder_main(threading.Thread, device_type):
|
|||||||
print(f'发送给界面完成,耗时{end - start:.3f}s。')
|
print(f'发送给界面完成,耗时{end - start:.3f}s。')
|
||||||
else: # 休息状态
|
else: # 休息状态
|
||||||
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||||
if self.thread_data_server.GetDataLenCount() < 25:
|
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
|
||||||
time.sleep(0.005)
|
time.sleep(0.005)
|
||||||
return
|
return
|
||||||
self.thread_data_server.getData(25)
|
self.thread_data_server.getData(25)
|
||||||
|
|
||||||
def decoder_concentration(self):
|
# def decoder_concentration(self):
|
||||||
if self.zmqServer.state_mode == 'predict':
|
# if self.zmqServer.state_mode == 'predict':
|
||||||
if self.zmqServer.StartDecode:
|
# if self.zmqServer.StartDecode:
|
||||||
self.zmqServer.StartDecode = False
|
# self.zmqServer.StartDecode = False
|
||||||
self.thread_data_server.ResetAll()
|
# self.thread_data_server.ResetAll()
|
||||||
now = datetime.now()
|
# now = datetime.now()
|
||||||
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
# formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
||||||
print('启动专注力预测 ', formatted_time)
|
# print('启动专注力预测 ', formatted_time)
|
||||||
if self.thread_data_server.GetDataLenCount() < int(self.win_step * self.fs): # 每win_step得出一次结果
|
# if self.thread_data_server.GetDataLenCount() < int(self.win_step * self.device_info['sample_rate']): # 每win_step得出一次结果
|
||||||
time.sleep(0.005)
|
# time.sleep(0.005)
|
||||||
return
|
# return
|
||||||
if self.zmqServer.get_Impedance != False: # 阻抗检测状态不解码
|
# if self.zmqServer.get_Impedance != False: # 阻抗检测状态不解码
|
||||||
return
|
# return
|
||||||
data = self.thread_data_server.get_concentrateData(int(self.win_step * self.fs)) # 修改每次读取的数据
|
# data = self.thread_data_server.get_concentrateData(int(self.win_step * self.device_info['sample_rate'])) # 修改每次读取的数据
|
||||||
result = self.calculate.queueOpt(data)
|
# result = self.calculate.queueOpt(data)
|
||||||
if result is not None:
|
# if result is not None:
|
||||||
self.zmqClient.send_to_all('result', int(result))
|
# self.zmqClient.send_to_all('result', int(result))
|
||||||
else: # 休息状态
|
# else: # 休息状态
|
||||||
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
# if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||||
if self.thread_data_server.GetDataLenCount() < 25:
|
# if self.thread_data_server.GetDataLenCount() < 25:
|
||||||
time.sleep(0.005)
|
# time.sleep(0.005)
|
||||||
return
|
# return
|
||||||
self.thread_data_server.getData(25)
|
# self.thread_data_server.getData(25)
|
||||||
|
|
||||||
#### Blink detection #####
|
#### Blink detection #####
|
||||||
def check_double_blink(self, current_time):
|
# def check_double_blink(self, current_time):
|
||||||
"""
|
# """
|
||||||
检查是否检测到连续两次眨眼
|
# 检查是否检测到连续两次眨眼
|
||||||
@param current_time: 当前眨眼时间戳
|
# @param current_time: 当前眨眼时间戳
|
||||||
@return: True表示检测到连续两次眨眼
|
# @return: True表示检测到连续两次眨眼
|
||||||
"""
|
# """
|
||||||
if len(self.blink_timestamps) < 2:
|
# if len(self.blink_timestamps) < 2:
|
||||||
return False
|
# return False
|
||||||
|
|
||||||
# 检查是否在去抖期内
|
# # 检查是否在去抖期内
|
||||||
if self.last_double_blink_time > 0:
|
# if self.last_double_blink_time > 0:
|
||||||
time_since_last_double_blink = current_time - self.last_double_blink_time
|
# time_since_last_double_blink = current_time - self.last_double_blink_time
|
||||||
if time_since_last_double_blink < self.double_blink_jitter:
|
# if time_since_last_double_blink < self.double_blink_jitter:
|
||||||
return False # 在去抖期内,忽略连续眨眼检测
|
# return False # 在去抖期内,忽略连续眨眼检测
|
||||||
last_time = self.blink_timestamps[-1] # 当前眨眼
|
# last_time = self.blink_timestamps[-1] # 当前眨眼
|
||||||
prev_time = self.blink_timestamps[-2] # 上次眨眼
|
# prev_time = self.blink_timestamps[-2] # 上次眨眼
|
||||||
|
|
||||||
interval = last_time - prev_time
|
# interval = last_time - prev_time
|
||||||
if interval <= self.double_blink_interval:
|
# if interval <= self.double_blink_interval:
|
||||||
return True
|
# return True
|
||||||
|
|
||||||
return False
|
# return False
|
||||||
|
|
||||||
def process_blink_detection(self):
|
# def process_blink_detection(self):
|
||||||
"""
|
# """
|
||||||
在缓冲区数据上执行,单次眨眼检测
|
# 在缓冲区数据上执行,单次眨眼检测
|
||||||
"""
|
# """
|
||||||
if len(self.fp1_buffer) < self.window_samples:
|
# if len(self.fp1_buffer) < self.window_samples:
|
||||||
return
|
# return
|
||||||
|
|
||||||
fp1_data = np.array(list(self.fp1_buffer)[-self.window_samples:])
|
# fp1_data = np.array(list(self.fp1_buffer)[-self.window_samples:])
|
||||||
fp2_data = np.array(list(self.fp2_buffer)[-self.window_samples:])
|
# fp2_data = np.array(list(self.fp2_buffer)[-self.window_samples:])
|
||||||
# 计算FP1和FP2的平均
|
# # 计算FP1和FP2的平均
|
||||||
fp12_mean = (fp1_data + fp2_data) / 2.0
|
# fp12_mean = (fp1_data + fp2_data) / 2.0
|
||||||
# 带通滤波
|
# # 带通滤波
|
||||||
try:
|
# try:
|
||||||
fp12_filtered = signal.filtfilt(self.blink_b, self.blink_a, fp12_mean)
|
# fp12_filtered = signal.filtfilt(self.blink_b, self.blink_a, fp12_mean)
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
print(f"Filter error: {e}")
|
# print(f"Filter error: {e}")
|
||||||
return
|
# return
|
||||||
F = np.diff(fp12_filtered)
|
# F = np.diff(fp12_filtered)
|
||||||
if len(F) < 3:
|
# if len(F) < 3:
|
||||||
return
|
# return
|
||||||
b, d, e = blink_detection(F, self.fs, self.Dmin, self.Dmax, self.EMin, self.EMax)
|
# b, d, e = blink_detection(F, self.device_info['sample_rate'], self.Dmin, self.Dmax, self.EMin, self.EMax)
|
||||||
|
|
||||||
if b == 1:
|
# if b == 1:
|
||||||
samples_since_last = self.total_samples - self.last_blink_time
|
# samples_since_last = self.total_samples - self.last_blink_time
|
||||||
time_since_last_ms = (samples_since_last / self.fs) * 1000
|
# time_since_last_ms = (samples_since_last / self.device_info['sample_rate']) * 1000
|
||||||
if time_since_last_ms >= self.jitterwin: # self.jitterwin 单次眨眼去抖 using time_since_last_ms
|
# if time_since_last_ms >= self.jitterwin: # self.jitterwin 单次眨眼去抖 using time_since_last_ms
|
||||||
self.blink_count += 1
|
# self.blink_count += 1
|
||||||
self.last_blink_time = self.total_samples
|
# self.last_blink_time = self.total_samples
|
||||||
current_time = time.time()
|
# current_time = time.time()
|
||||||
self.blink_timestamps.append(current_time)
|
# self.blink_timestamps.append(current_time)
|
||||||
blink_event = {
|
# blink_event = {
|
||||||
'count': self.blink_count,
|
# 'count': self.blink_count,
|
||||||
'time': current_time,
|
# 'time': current_time,
|
||||||
'sample_index': self.total_samples,
|
# 'sample_index': self.total_samples,
|
||||||
'duration_ms': d,
|
# 'duration_ms': d,
|
||||||
'energy': e
|
# 'energy': e
|
||||||
}
|
# }
|
||||||
self.blink_events.append(blink_event)
|
# self.blink_events.append(blink_event)
|
||||||
self.zmqClient.send_to_all('result', 1) # 检测到眨眼信号,通知上位机
|
# self.zmqClient.send_to_all('result', 1) # 检测到眨眼信号,通知上位机
|
||||||
if self.check_double_blink(current_time):
|
# if self.check_double_blink(current_time):
|
||||||
self.double_blink_count += 1
|
# self.double_blink_count += 1
|
||||||
interval = self.blink_timestamps[-1] - self.blink_timestamps[-2]
|
# interval = self.blink_timestamps[-1] - self.blink_timestamps[-2]
|
||||||
double_blink_event = {
|
# double_blink_event = {
|
||||||
'double_blink_count': self.double_blink_count,
|
# 'double_blink_count': self.double_blink_count,
|
||||||
'blink1_time': self.blink_timestamps[-2],
|
# 'blink1_time': self.blink_timestamps[-2],
|
||||||
'blink2_time': self.blink_timestamps[-1],
|
# 'blink2_time': self.blink_timestamps[-1],
|
||||||
'interval': interval
|
# 'interval': interval
|
||||||
}
|
# }
|
||||||
self.double_blink_events.append(double_blink_event)
|
# self.double_blink_events.append(double_blink_event)
|
||||||
self.last_double_blink_time = current_time
|
# self.last_double_blink_time = current_time
|
||||||
self.zmqClient.send_to_all('result', 2) # 发送双次眨眼事件
|
# self.zmqClient.send_to_all('result', 2) # 发送双次眨眼事件
|
||||||
|
|
||||||
def decoder_blink(self):
|
# def decoder_blink(self):
|
||||||
if self.thread_data_server.GetDataLenCount() < 50:
|
# if self.thread_data_server.GetDataLenCount() < 50:
|
||||||
time.sleep(0.005)
|
# time.sleep(0.005)
|
||||||
return
|
# return
|
||||||
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
# if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||||
data = self.thread_data_server.get_blinkData(50)
|
# data = self.thread_data_server.get_blinkData(50)
|
||||||
fp1_data = data[0, :] # ch1 (相当于FP1)
|
# fp1_data = data[0, :] # ch1 (相当于FP1)
|
||||||
fp2_data = data[1, :] # ch2 (相当于FP2)
|
# fp2_data = data[1, :] # ch2 (相当于FP2)
|
||||||
for i in range(len(fp1_data)):
|
# for i in range(len(fp1_data)):
|
||||||
self.fp1_buffer.append(fp1_data[i])
|
# self.fp1_buffer.append(fp1_data[i])
|
||||||
self.fp2_buffer.append(fp2_data[i])
|
# self.fp2_buffer.append(fp2_data[i])
|
||||||
self.total_samples += 1
|
# self.total_samples += 1
|
||||||
self.sample_counter += 1
|
# self.sample_counter += 1
|
||||||
|
|
||||||
if self.sample_counter >= self.step_samples:
|
# if self.sample_counter >= self.step_samples:
|
||||||
self.process_blink_detection()
|
# self.process_blink_detection()
|
||||||
self.sample_counter = 0
|
# self.sample_counter = 0
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
'''
|
'''
|
||||||
|
|||||||
@@ -63,8 +63,53 @@ class ParadigmRingBuffer:
|
|||||||
获取最新缓存中每个通道的数量
|
获取最新缓存中每个通道的数量
|
||||||
@return:
|
@return:
|
||||||
'''
|
'''
|
||||||
return self.nUpdate
|
return self.nUpdate
|
||||||
|
|
||||||
|
# ========== 各范式数据访问接口 ==========
|
||||||
|
def get_MIData(self):
|
||||||
|
"""获取MI导联数据 (21通道 + 事件)"""
|
||||||
|
data = self.getData(self.GetDataLenCount())
|
||||||
|
rows_to_extract = [8, 15, 12, 14, 18, 23, 16, 59, 50, 58, 17, 45, 29, 11, 10, 19, 20, 61, 51, 60, 21, 64, 65]
|
||||||
|
row_to_select = np.array(rows_to_extract)
|
||||||
|
if data.shape[1] > 0:
|
||||||
|
return data[row_to_select, :]
|
||||||
|
return np.zeros((len(rows_to_extract), 0))
|
||||||
|
|
||||||
|
def get_SSMVEPData(self):
|
||||||
|
"""获取SSMVEP导联数据 (8通道 + 事件)"""
|
||||||
|
data = self.getData(self.GetDataLenCount())
|
||||||
|
rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64, 65]
|
||||||
|
row_to_select = np.array(rows_to_extract)
|
||||||
|
if data.shape[1] > 0:
|
||||||
|
return data[row_to_select, :]
|
||||||
|
return np.zeros((len(rows_to_extract), 0))
|
||||||
|
|
||||||
|
def getDataViaSSVEP(self, count):
|
||||||
|
"""获取SSVEP数据 (8通道 + 事件)"""
|
||||||
|
data = self.getData(count)
|
||||||
|
rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64]
|
||||||
|
row_to_select = np.array(rows_to_extract)
|
||||||
|
if data.shape[1] > 0:
|
||||||
|
return data[row_to_select, :]
|
||||||
|
return np.zeros((len(rows_to_extract), 0))
|
||||||
|
|
||||||
|
def get_concentrateData(self, count):
|
||||||
|
"""获取专注力数据 (2通道)"""
|
||||||
|
data = self.getData(count)
|
||||||
|
rows_to_extract = [0, 1]
|
||||||
|
row_to_select = np.array(rows_to_extract)
|
||||||
|
if data.shape[1] > 0:
|
||||||
|
return data[row_to_select, :]
|
||||||
|
return np.zeros((len(rows_to_extract), 0))
|
||||||
|
|
||||||
|
def get_blinkData(self, count):
|
||||||
|
"""获取眨眼数据 (2通道)"""
|
||||||
|
data = self.getData(count)
|
||||||
|
rows_to_extract = [0, 1]
|
||||||
|
row_to_select = np.array(rows_to_extract)
|
||||||
|
if data.shape[1] > 0:
|
||||||
|
return data[row_to_select, :]
|
||||||
|
return np.zeros((len(rows_to_extract), 0))
|
||||||
|
|
||||||
# reset buffer
|
# reset buffer
|
||||||
def resetAllPara(self):
|
def resetAllPara(self):
|
||||||
@@ -72,6 +117,4 @@ class ParadigmRingBuffer:
|
|||||||
self.currentPtr = 0
|
self.currentPtr = 0
|
||||||
self.readPtr = 0 # add by lizhenhua 清空读指针
|
self.readPtr = 0 # add by lizhenhua 清空读指针
|
||||||
self.buffer = np.zeros((self.n_chan, self.n_points)) # add by lizhenhua 清空环形缓冲区
|
self.buffer = np.zeros((self.n_chan, self.n_points)) # add by lizhenhua 清空环形缓冲区
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
122
Zmq/zmqServer.py
122
Zmq/zmqServer.py
@@ -1,16 +1,22 @@
|
|||||||
|
import ast
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import zmq
|
|
||||||
import threading
|
import threading
|
||||||
import json
|
import json
|
||||||
import queue
|
import queue
|
||||||
|
from typing import Dict
|
||||||
# from Device.SunnyLinker import SunnyLinker64
|
# from Device.SunnyLinker import SunnyLinker64
|
||||||
from dataBuffer import ParadigmRingBuffer
|
from dataBuffer import ParadigmRingBuffer
|
||||||
from filterProcess import FilterRingBuffer
|
from filterProcess import FilterRingBuffer
|
||||||
|
from PubLibrary.InifileHelper import IniRead
|
||||||
from logs.log import algo_log
|
from logs.log import algo_log
|
||||||
|
|
||||||
|
import zmq
|
||||||
|
|
||||||
class zmqServer(threading.Thread):
|
class zmqServer(threading.Thread):
|
||||||
def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100, device_info=None):
|
def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100, device_info=None):
|
||||||
threading.Thread.__init__(self)
|
threading.Thread.__init__(self)
|
||||||
|
self.device_info = device_info
|
||||||
|
|
||||||
self.host = host
|
self.host = host
|
||||||
self.cmd_port = cmd_port # 命令交互端口
|
self.cmd_port = cmd_port # 命令交互端口
|
||||||
self.data_port = data_port # 数据接收端口
|
self.data_port = data_port # 数据接收端口
|
||||||
@@ -28,8 +34,8 @@ class zmqServer(threading.Thread):
|
|||||||
self.daemon = True
|
self.daemon = True
|
||||||
|
|
||||||
# 范式数据缓存
|
# 范式数据缓存
|
||||||
self.paradigmBuffer = ParadigmRingBuffer(66, 2500)
|
self.paradigmBuffer = ParadigmRingBuffer(self.device_info['channel_nums'], self.device_info['sample_rate'] * 10)
|
||||||
self.filterBuffer = FilterRingBuffer(66, 2500)
|
self.filterBuffer = FilterRingBuffer(self.device_info['channel_nums'], self.device_info['sample_rate'] * 10)
|
||||||
|
|
||||||
|
|
||||||
# 命令与数据通信
|
# 命令与数据通信
|
||||||
@@ -64,6 +70,77 @@ class zmqServer(threading.Thread):
|
|||||||
self.cmd_clients = set() # 命令端口客户端ID
|
self.cmd_clients = set() # 命令端口客户端ID
|
||||||
self.data_clients = set() # 数据端口客户端ID
|
self.data_clients = set() # 数据端口客户端ID
|
||||||
self.send_queue = queue.Queue() # 发送队列(仅用于命令端口广播)
|
self.send_queue = queue.Queue() # 发送队列(仅用于命令端口广播)
|
||||||
|
|
||||||
|
|
||||||
|
# 范式buffer参数, 事件检测相关
|
||||||
|
self._event_lock = threading.Lock()
|
||||||
|
self._epoch_finished = False
|
||||||
|
self._event_inner_idx = -1
|
||||||
|
self.pack_contain_event = False
|
||||||
|
self.predict_event = 99
|
||||||
|
self.events = [1, 2, self.predict_event]
|
||||||
|
self.count_events = {}
|
||||||
|
self.latency = 50
|
||||||
|
self.train_latency = 50
|
||||||
|
self._interval_inited = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def interval_inited(self):
|
||||||
|
return self._interval_inited
|
||||||
|
|
||||||
|
@interval_inited.setter
|
||||||
|
def interval_inited(self, value):
|
||||||
|
self._interval_inited = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def epoch_finished(self):
|
||||||
|
with self._event_lock:
|
||||||
|
return self._epoch_finished
|
||||||
|
|
||||||
|
@epoch_finished.setter
|
||||||
|
def epoch_finished(self, value):
|
||||||
|
with self._event_lock:
|
||||||
|
self._epoch_finished = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def event_inner_idx(self):
|
||||||
|
with self._event_lock:
|
||||||
|
return self._event_inner_idx
|
||||||
|
|
||||||
|
@event_inner_idx.setter
|
||||||
|
def event_inner_idx(self, value):
|
||||||
|
with self._event_lock:
|
||||||
|
self._event_inner_idx = value
|
||||||
|
|
||||||
|
def interval_init(self, decoder_class):
|
||||||
|
if decoder_class == 'ssmvep':
|
||||||
|
interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
|
||||||
|
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] # epoch截取信息
|
||||||
|
self.train_epoch = [int(self.interval_epoch[0]),
|
||||||
|
int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])] # 训练样本epoch
|
||||||
|
self.latency = (self.interval_epoch[
|
||||||
|
1] + 0.1 * self.device_info['sample_rate']) // 5 # 提取epoch的延迟标记,5代表每次解包得到的5位采样点;0.1表示比实际需要的长度多取0.1,会被截掉
|
||||||
|
self.train_latency = (self.train_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5
|
||||||
|
|
||||||
|
elif decoder_class == 'mi':
|
||||||
|
interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
|
||||||
|
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] # epoch截取信息
|
||||||
|
self.train_epoch = self.interval_epoch.copy()
|
||||||
|
self.latency = (self.interval_epoch[1]) // 5 # 提取epoch的延迟标记,5代表每次解包得到的5位采样点;
|
||||||
|
self.train_latency = self.latency
|
||||||
|
|
||||||
|
print('时间窗:', (interval_epoch))
|
||||||
|
self.count_events: Dict[str, int] = {} # 表示包延迟的计数信息
|
||||||
|
self.event_inner_idx = -1 # event在5位数据包内部的idx
|
||||||
|
self.epoch_finished = False # 接收epoch是否完整
|
||||||
|
self.pack_contain_event = False # 当前包是否含有event
|
||||||
|
self.predict_event = 99
|
||||||
|
self.events = [1, 2, self.predict_event]
|
||||||
|
self.interval_inited = True
|
||||||
|
# if getattr(self, 'serial', None) and self.serial.is_open:
|
||||||
|
# self.serial.close()
|
||||||
|
# self.serial = serial.Serial(self.serial_port, 460800, timeout=1) # 连接同步器串口
|
||||||
|
|
||||||
|
|
||||||
def broadcast_message(self, method, params):
|
def broadcast_message(self, method, params):
|
||||||
"""Put message into queue to be sent to all command clients"""
|
"""Put message into queue to be sent to all command clients"""
|
||||||
@@ -78,15 +155,15 @@ class zmqServer(threading.Thread):
|
|||||||
# 注册新的命令客户端
|
# 注册新的命令客户端
|
||||||
if ident not in self.cmd_clients:
|
if ident not in self.cmd_clients:
|
||||||
self.cmd_clients.add(ident)
|
self.cmd_clients.add(ident)
|
||||||
print(f"New CMD Client Connected: {ident} (port: {self.cmd_port})")
|
algo_log(f"New CMD Client Connected: {ident} (port: {self.cmd_port})")
|
||||||
|
|
||||||
# 解析消息
|
# 解析消息
|
||||||
try:
|
try:
|
||||||
message = json.loads(message_bytes.decode('utf-8'))
|
message = json.loads(message_bytes.decode('utf-8'))
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
print(f"Invalid JSON from CMD client {ident}")
|
algo_log(f"Invalid JSON from CMD client {ident}")
|
||||||
continue
|
return
|
||||||
print(f"Received CMD request: {message}")
|
algo_log(f"Received CMD request: {message}")
|
||||||
|
|
||||||
method = message.get("method")
|
method = message.get("method")
|
||||||
params = message.get("params")
|
params = message.get("params")
|
||||||
@@ -94,37 +171,40 @@ class zmqServer(threading.Thread):
|
|||||||
# 原有命令处理逻辑
|
# 原有命令处理逻辑
|
||||||
if method == "sync":
|
if method == "sync":
|
||||||
self.state_mode = 'sync'
|
self.state_mode = 'sync'
|
||||||
if method == "targetFreqs":
|
elif method == "targetFreqs":
|
||||||
if not isinstance(params, list):
|
if not isinstance(params, list):
|
||||||
print('targetFreqs must be a list')
|
algo_log(f"targetFreqs must be a list")
|
||||||
continue
|
return
|
||||||
if params != self.targetFreqs:
|
if params != self.targetFreqs:
|
||||||
self.targetFreqs = params
|
self.targetFreqs = params
|
||||||
self.changeTarget = True
|
self.changeTarget = True
|
||||||
if method == "decoderClass":
|
elif method == "decoderClass":
|
||||||
if not isinstance(params, str):
|
if not isinstance(params, str):
|
||||||
print('decoderClass must be a str')
|
algo_log(f"decoderClass must be a str")
|
||||||
continue
|
return
|
||||||
if params != self.decoder_class:
|
if params != self.decoder_class:
|
||||||
self.decoder_class = params
|
self.decoder_class = params
|
||||||
self.decoder_switch = True
|
self.decoder_switch = True
|
||||||
if method == "getReport":
|
elif method == "train":#训练状态
|
||||||
self.getReport = True
|
|
||||||
if method == "train":#训练状态
|
|
||||||
self.state_mode = 'train'
|
self.state_mode = 'train'
|
||||||
self.StartTrain = True
|
self.StartTrain = True
|
||||||
self.currentLabel = params # 当前刺激端的训练标签
|
self.currentLabel = params # 当前刺激端的训练标签
|
||||||
self.sunnyLinker.push_trigger(self.labels[self.currentLabel])
|
# self.sunnyLinker.push_trigger(self.labels[self.currentLabel])
|
||||||
elif method == "predict":#预测状态
|
elif method == "predict":#预测状态
|
||||||
self.state_mode = 'predict'
|
self.state_mode = 'predict'
|
||||||
if params == 1: #开始解码
|
if params == 1: #开始解码
|
||||||
self.StartDecode = True
|
self.StartDecode = True
|
||||||
self.sunnyLinker.push_trigger(0x63)
|
# self.sunnyLinker.push_trigger(0x63)
|
||||||
elif params == 2: #停止解码
|
elif params == 2: #停止解码
|
||||||
self.IsExitApp = True
|
self.IsExitApp = True
|
||||||
self.running = False
|
self.running = False
|
||||||
elif method == "rest": #休息状态
|
elif method == "rest": #休息状态
|
||||||
self.state_mode = 'rest'
|
self.state_mode = 'rest'
|
||||||
|
else:
|
||||||
|
algo_log(f"未知命令:{method}", level="WARNING")
|
||||||
|
|
||||||
|
# elif method == "getReport":
|
||||||
|
# self.getReport = True
|
||||||
# elif method == "impedance":
|
# elif method == "impedance":
|
||||||
# if params == 1:
|
# if params == 1:
|
||||||
# self.open_Impedance = True # 开启阻抗
|
# self.open_Impedance = True # 开启阻抗
|
||||||
@@ -153,7 +233,7 @@ class zmqServer(threading.Thread):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 3. 精确长度校验(核心:固定(5,66) float32 = 5*66*4=1320字节,与int32字节数相同)
|
# 3. 精确长度校验(核心:固定(5,66) float32 = 5*66*4=1320字节,与int32字节数相同)
|
||||||
EXPECTED_BYTES = 5 * 66 * 4 # 每个float32占4字节
|
EXPECTED_BYTES = self.device_info['frame_points'] * self.device_info['channel_nums'] * 4 # 每个float32占4字节
|
||||||
if len(data_bytes) != EXPECTED_BYTES:
|
if len(data_bytes) != EXPECTED_BYTES:
|
||||||
print(f"[ERROR] 数据长度错误:期望{EXPECTED_BYTES}字节,实际{len(data_bytes)}字节")
|
print(f"[ERROR] 数据长度错误:期望{EXPECTED_BYTES}字节,实际{len(data_bytes)}字节")
|
||||||
return
|
return
|
||||||
@@ -162,7 +242,7 @@ class zmqServer(threading.Thread):
|
|||||||
# 步骤:字节流 → (330,) float32数组 → (5,66) 原始格式 → 转置为 (66,5) 缓冲区标准格式
|
# 步骤:字节流 → (330,) float32数组 → (5,66) 原始格式 → 转置为 (66,5) 缓冲区标准格式
|
||||||
data_np = np.frombuffer(data_bytes, dtype=np.float32)
|
data_np = np.frombuffer(data_bytes, dtype=np.float32)
|
||||||
# 重塑为上位机原始维度
|
# 重塑为上位机原始维度
|
||||||
data_np = data_np.reshape(5, 66)
|
data_np = data_np.reshape(self.device_info['frame_points'], self.device_info['channel_nums'])
|
||||||
# 转置为(通道数, 采样点数)标准格式,转换为float64保证滤波运算精度
|
# 转置为(通道数, 采样点数)标准格式,转换为float64保证滤波运算精度
|
||||||
data_np = data_np.T.astype(np.float64)
|
data_np = data_np.T.astype(np.float64)
|
||||||
|
|
||||||
@@ -215,7 +295,7 @@ class zmqServer(threading.Thread):
|
|||||||
self._process_send_queue()
|
self._process_send_queue()
|
||||||
|
|
||||||
# 2. 轮训监听两个Socket的输入事件(10ms超时,避免阻塞)
|
# 2. 轮训监听两个Socket的输入事件(10ms超时,避免阻塞)
|
||||||
socks = dict(self.poller.poll(10))
|
socks = dict(self.poller.poll(50))
|
||||||
|
|
||||||
# 处理命令端口消息
|
# 处理命令端口消息
|
||||||
if self.cmd_socket in socks and socks[self.cmd_socket] == zmq.POLLIN:
|
if self.cmd_socket in socks and socks[self.cmd_socket] == zmq.POLLIN:
|
||||||
|
|||||||
@@ -24,10 +24,11 @@ console_output = 1
|
|||||||
|
|
||||||
; 64 导设备配置
|
; 64 导设备配置
|
||||||
[device_type_1]
|
[device_type_1]
|
||||||
device_sample_rate = 250
|
sample_rate = 250
|
||||||
device_channel_nums = 66
|
frame_points = 5
|
||||||
device_channel_names = ['FP1', 'FP2', 'FC1', 'FC2', 'CP1', 'CP2', 'F3', 'F4', 'P3', 'P4', 'O1', 'O2', 'FT9', 'FT10', 'F7', 'F8', 'TP9', 'TP10', 'AF4', 'PO8', 'PZ', 'FCZ']
|
channel_nums = 66
|
||||||
device_channel_index = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18]
|
channel_names = ['FP1', 'FP2', 'FC1', 'FC2', 'CP1', 'CP2', 'F3', 'F4', 'P3', 'P4', 'O1', 'O2', 'FT9', 'FT10', 'F7', 'F8', 'TP9', 'TP10', 'AF4', 'PO8', 'PZ', 'FCZ']
|
||||||
|
channel_index = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -9,29 +9,28 @@ from PubLibrary.InifileHelper import IniRead
|
|||||||
|
|
||||||
def get_device_info(device_type):
|
def get_device_info(device_type):
|
||||||
|
|
||||||
|
|
||||||
section = f'device_type_{device_type}'
|
section = f'device_type_{device_type}'
|
||||||
device_info = {
|
device_info = {
|
||||||
'device_sample_rate': int(IniRead(section, 'sample_rate')) if IniRead(section, 'sample_rate') is not None else 250,
|
'sample_rate': int(IniRead(section, 'sample_rate')) if IniRead(section, 'sample_rate') is not None else 250,
|
||||||
|
'channel_nums': int(IniRead(section, 'channel_nums')) if IniRead(section, 'channel_nums') is not None else 66,
|
||||||
''
|
'channel_names': IniRead(section, 'channel_names') if IniRead(section, 'channel_names') is not None else None,
|
||||||
|
'channel_index': IniRead(section, 'channel_index') if IniRead(section, 'channel_index') is not None else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return device_info
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
if not is_program_running():
|
if not is_program_running():
|
||||||
# 解析命令行参数
|
# 解析命令行参数
|
||||||
parser = argparse.ArgumentParser(description="EEG Decoder Application")
|
# parser = argparse.ArgumentParser(description="EEG Decoder Application")
|
||||||
parser.add_argument('-dt', '-t','--device-type', type=int, default=None, help="Device Type")
|
# parser.add_argument('-dt', '-t','--device-type', type=int, default=None, help="Device Type")
|
||||||
# parser.add_argument('-dh', '--device-host', type=str, default=None, help="Device Host IP")
|
# parser.add_argument('-dh', '--device-host', type=str, default=None, help="Device Host IP")
|
||||||
# parser.add_argument('-dp', '--device-port', type=int, default=None, help="Device Port")
|
# parser.add_argument('-dp', '--device-port', type=int, default=None, help="Device Port")
|
||||||
# parser.add_argument('-uh', '--upper-host', type=str, default=None, help="Upper Computer Host IP")
|
# parser.add_argument('-uh', '--upper-host', type=str, default=None, help="Upper Computer Host IP")
|
||||||
# parser.add_argument('-up', '--upper-port', type=int, default=None, help="Upper Computer Port")
|
# parser.add_argument('-up', '--upper-port', type=int, default=None, help="Upper Computer Port")
|
||||||
|
# args = parser.parse_args()
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
device_info= get_device_info(args.device_type)
|
|
||||||
|
|
||||||
|
|
||||||
decoder = Decoder_main(device_info=device_info)
|
|
||||||
# decoder.connect(
|
# decoder.connect(
|
||||||
# device_type=args.device_type,
|
# device_type=args.device_type,
|
||||||
# device_host=args.device_host,
|
# device_host=args.device_host,
|
||||||
@@ -40,6 +39,9 @@ if __name__ == "__main__":
|
|||||||
# upper_port=args.upper_port
|
# upper_port=args.upper_port
|
||||||
# )
|
# )
|
||||||
|
|
||||||
|
device_info= get_device_info(1)
|
||||||
|
decoder = Decoder_main(device_info=device_info)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
decoder.start()
|
decoder.start()
|
||||||
while not decoder.zmqServer.IsExitApp:
|
while not decoder.zmqServer.IsExitApp:
|
||||||
|
|||||||
Reference in New Issue
Block a user