add buffer
This commit is contained in:
@@ -63,8 +63,53 @@ class ParadigmRingBuffer:
|
||||
获取最新缓存中每个通道的数量
|
||||
@return:
|
||||
'''
|
||||
return self.nUpdate
|
||||
return self.nUpdate
|
||||
|
||||
# ========== 各范式数据访问接口 ==========
|
||||
def get_MIData(self):
|
||||
"""获取MI导联数据 (21通道 + 事件)"""
|
||||
data = self.getData(self.GetDataLenCount())
|
||||
rows_to_extract = [8, 15, 12, 14, 18, 23, 16, 59, 50, 58, 17, 45, 29, 11, 10, 19, 20, 61, 51, 60, 21, 64, 65]
|
||||
row_to_select = np.array(rows_to_extract)
|
||||
if data.shape[1] > 0:
|
||||
return data[row_to_select, :]
|
||||
return np.zeros((len(rows_to_extract), 0))
|
||||
|
||||
def get_SSMVEPData(self):
|
||||
"""获取SSMVEP导联数据 (8通道 + 事件)"""
|
||||
data = self.getData(self.GetDataLenCount())
|
||||
rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64, 65]
|
||||
row_to_select = np.array(rows_to_extract)
|
||||
if data.shape[1] > 0:
|
||||
return data[row_to_select, :]
|
||||
return np.zeros((len(rows_to_extract), 0))
|
||||
|
||||
def getDataViaSSVEP(self, count):
|
||||
"""获取SSVEP数据 (8通道 + 事件)"""
|
||||
data = self.getData(count)
|
||||
rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64]
|
||||
row_to_select = np.array(rows_to_extract)
|
||||
if data.shape[1] > 0:
|
||||
return data[row_to_select, :]
|
||||
return np.zeros((len(rows_to_extract), 0))
|
||||
|
||||
def get_concentrateData(self, count):
|
||||
"""获取专注力数据 (2通道)"""
|
||||
data = self.getData(count)
|
||||
rows_to_extract = [0, 1]
|
||||
row_to_select = np.array(rows_to_extract)
|
||||
if data.shape[1] > 0:
|
||||
return data[row_to_select, :]
|
||||
return np.zeros((len(rows_to_extract), 0))
|
||||
|
||||
def get_blinkData(self, count):
|
||||
"""获取眨眼数据 (2通道)"""
|
||||
data = self.getData(count)
|
||||
rows_to_extract = [0, 1]
|
||||
row_to_select = np.array(rows_to_extract)
|
||||
if data.shape[1] > 0:
|
||||
return data[row_to_select, :]
|
||||
return np.zeros((len(rows_to_extract), 0))
|
||||
|
||||
# reset buffer
|
||||
def resetAllPara(self):
|
||||
@@ -72,6 +117,4 @@ class ParadigmRingBuffer:
|
||||
self.currentPtr = 0
|
||||
self.readPtr = 0 # add by lizhenhua 清空读指针
|
||||
self.buffer = np.zeros((self.n_chan, self.n_points)) # add by lizhenhua 清空环形缓冲区
|
||||
|
||||
|
||||
|
||||
|
||||
122
Zmq/zmqServer.py
122
Zmq/zmqServer.py
@@ -1,16 +1,22 @@
|
||||
import ast
|
||||
import numpy as np
|
||||
import zmq
|
||||
import threading
|
||||
import json
|
||||
import queue
|
||||
from typing import Dict
|
||||
# from Device.SunnyLinker import SunnyLinker64
|
||||
from dataBuffer import ParadigmRingBuffer
|
||||
from filterProcess import FilterRingBuffer
|
||||
from PubLibrary.InifileHelper import IniRead
|
||||
from logs.log import algo_log
|
||||
|
||||
import zmq
|
||||
|
||||
class zmqServer(threading.Thread):
|
||||
def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100, device_info=None):
|
||||
threading.Thread.__init__(self)
|
||||
self.device_info = device_info
|
||||
|
||||
self.host = host
|
||||
self.cmd_port = cmd_port # 命令交互端口
|
||||
self.data_port = data_port # 数据接收端口
|
||||
@@ -28,8 +34,8 @@ class zmqServer(threading.Thread):
|
||||
self.daemon = True
|
||||
|
||||
# 范式数据缓存
|
||||
self.paradigmBuffer = ParadigmRingBuffer(66, 2500)
|
||||
self.filterBuffer = FilterRingBuffer(66, 2500)
|
||||
self.paradigmBuffer = ParadigmRingBuffer(self.device_info['channel_nums'], self.device_info['sample_rate'] * 10)
|
||||
self.filterBuffer = FilterRingBuffer(self.device_info['channel_nums'], self.device_info['sample_rate'] * 10)
|
||||
|
||||
|
||||
# 命令与数据通信
|
||||
@@ -64,6 +70,77 @@ class zmqServer(threading.Thread):
|
||||
self.cmd_clients = set() # 命令端口客户端ID
|
||||
self.data_clients = set() # 数据端口客户端ID
|
||||
self.send_queue = queue.Queue() # 发送队列(仅用于命令端口广播)
|
||||
|
||||
|
||||
# 范式buffer参数, 事件检测相关
|
||||
self._event_lock = threading.Lock()
|
||||
self._epoch_finished = False
|
||||
self._event_inner_idx = -1
|
||||
self.pack_contain_event = False
|
||||
self.predict_event = 99
|
||||
self.events = [1, 2, self.predict_event]
|
||||
self.count_events = {}
|
||||
self.latency = 50
|
||||
self.train_latency = 50
|
||||
self._interval_inited = False
|
||||
|
||||
@property
|
||||
def interval_inited(self):
|
||||
return self._interval_inited
|
||||
|
||||
@interval_inited.setter
|
||||
def interval_inited(self, value):
|
||||
self._interval_inited = value
|
||||
|
||||
@property
|
||||
def epoch_finished(self):
|
||||
with self._event_lock:
|
||||
return self._epoch_finished
|
||||
|
||||
@epoch_finished.setter
|
||||
def epoch_finished(self, value):
|
||||
with self._event_lock:
|
||||
self._epoch_finished = value
|
||||
|
||||
@property
|
||||
def event_inner_idx(self):
|
||||
with self._event_lock:
|
||||
return self._event_inner_idx
|
||||
|
||||
@event_inner_idx.setter
|
||||
def event_inner_idx(self, value):
|
||||
with self._event_lock:
|
||||
self._event_inner_idx = value
|
||||
|
||||
def interval_init(self, decoder_class):
|
||||
if decoder_class == 'ssmvep':
|
||||
interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
|
||||
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] # epoch截取信息
|
||||
self.train_epoch = [int(self.interval_epoch[0]),
|
||||
int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])] # 训练样本epoch
|
||||
self.latency = (self.interval_epoch[
|
||||
1] + 0.1 * self.device_info['sample_rate']) // 5 # 提取epoch的延迟标记,5代表每次解包得到的5位采样点;0.1表示比实际需要的长度多取0.1,会被截掉
|
||||
self.train_latency = (self.train_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5
|
||||
|
||||
elif decoder_class == 'mi':
|
||||
interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
|
||||
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] # epoch截取信息
|
||||
self.train_epoch = self.interval_epoch.copy()
|
||||
self.latency = (self.interval_epoch[1]) // 5 # 提取epoch的延迟标记,5代表每次解包得到的5位采样点;
|
||||
self.train_latency = self.latency
|
||||
|
||||
print('时间窗:', (interval_epoch))
|
||||
self.count_events: Dict[str, int] = {} # 表示包延迟的计数信息
|
||||
self.event_inner_idx = -1 # event在5位数据包内部的idx
|
||||
self.epoch_finished = False # 接收epoch是否完整
|
||||
self.pack_contain_event = False # 当前包是否含有event
|
||||
self.predict_event = 99
|
||||
self.events = [1, 2, self.predict_event]
|
||||
self.interval_inited = True
|
||||
# if getattr(self, 'serial', None) and self.serial.is_open:
|
||||
# self.serial.close()
|
||||
# self.serial = serial.Serial(self.serial_port, 460800, timeout=1) # 连接同步器串口
|
||||
|
||||
|
||||
def broadcast_message(self, method, params):
|
||||
"""Put message into queue to be sent to all command clients"""
|
||||
@@ -78,15 +155,15 @@ class zmqServer(threading.Thread):
|
||||
# 注册新的命令客户端
|
||||
if ident not in self.cmd_clients:
|
||||
self.cmd_clients.add(ident)
|
||||
print(f"New CMD Client Connected: {ident} (port: {self.cmd_port})")
|
||||
algo_log(f"New CMD Client Connected: {ident} (port: {self.cmd_port})")
|
||||
|
||||
# 解析消息
|
||||
try:
|
||||
message = json.loads(message_bytes.decode('utf-8'))
|
||||
except json.JSONDecodeError:
|
||||
print(f"Invalid JSON from CMD client {ident}")
|
||||
continue
|
||||
print(f"Received CMD request: {message}")
|
||||
algo_log(f"Invalid JSON from CMD client {ident}")
|
||||
return
|
||||
algo_log(f"Received CMD request: {message}")
|
||||
|
||||
method = message.get("method")
|
||||
params = message.get("params")
|
||||
@@ -94,37 +171,40 @@ class zmqServer(threading.Thread):
|
||||
# 原有命令处理逻辑
|
||||
if method == "sync":
|
||||
self.state_mode = 'sync'
|
||||
if method == "targetFreqs":
|
||||
elif method == "targetFreqs":
|
||||
if not isinstance(params, list):
|
||||
print('targetFreqs must be a list')
|
||||
continue
|
||||
algo_log(f"targetFreqs must be a list")
|
||||
return
|
||||
if params != self.targetFreqs:
|
||||
self.targetFreqs = params
|
||||
self.changeTarget = True
|
||||
if method == "decoderClass":
|
||||
elif method == "decoderClass":
|
||||
if not isinstance(params, str):
|
||||
print('decoderClass must be a str')
|
||||
continue
|
||||
algo_log(f"decoderClass must be a str")
|
||||
return
|
||||
if params != self.decoder_class:
|
||||
self.decoder_class = params
|
||||
self.decoder_switch = True
|
||||
if method == "getReport":
|
||||
self.getReport = True
|
||||
if method == "train":#训练状态
|
||||
elif method == "train":#训练状态
|
||||
self.state_mode = 'train'
|
||||
self.StartTrain = True
|
||||
self.currentLabel = params # 当前刺激端的训练标签
|
||||
self.sunnyLinker.push_trigger(self.labels[self.currentLabel])
|
||||
# self.sunnyLinker.push_trigger(self.labels[self.currentLabel])
|
||||
elif method == "predict":#预测状态
|
||||
self.state_mode = 'predict'
|
||||
if params == 1: #开始解码
|
||||
self.StartDecode = True
|
||||
self.sunnyLinker.push_trigger(0x63)
|
||||
# self.sunnyLinker.push_trigger(0x63)
|
||||
elif params == 2: #停止解码
|
||||
self.IsExitApp = True
|
||||
self.running = False
|
||||
elif method == "rest": #休息状态
|
||||
self.state_mode = 'rest'
|
||||
else:
|
||||
algo_log(f"未知命令:{method}", level="WARNING")
|
||||
|
||||
# elif method == "getReport":
|
||||
# self.getReport = True
|
||||
# elif method == "impedance":
|
||||
# if params == 1:
|
||||
# self.open_Impedance = True # 开启阻抗
|
||||
@@ -153,7 +233,7 @@ class zmqServer(threading.Thread):
|
||||
|
||||
try:
|
||||
# 3. 精确长度校验(核心:固定(5,66) float32 = 5*66*4=1320字节,与int32字节数相同)
|
||||
EXPECTED_BYTES = 5 * 66 * 4 # 每个float32占4字节
|
||||
EXPECTED_BYTES = self.device_info['frame_points'] * self.device_info['channel_nums'] * 4 # 每个float32占4字节
|
||||
if len(data_bytes) != EXPECTED_BYTES:
|
||||
print(f"[ERROR] 数据长度错误:期望{EXPECTED_BYTES}字节,实际{len(data_bytes)}字节")
|
||||
return
|
||||
@@ -162,7 +242,7 @@ class zmqServer(threading.Thread):
|
||||
# 步骤:字节流 → (330,) float32数组 → (5,66) 原始格式 → 转置为 (66,5) 缓冲区标准格式
|
||||
data_np = np.frombuffer(data_bytes, dtype=np.float32)
|
||||
# 重塑为上位机原始维度
|
||||
data_np = data_np.reshape(5, 66)
|
||||
data_np = data_np.reshape(self.device_info['frame_points'], self.device_info['channel_nums'])
|
||||
# 转置为(通道数, 采样点数)标准格式,转换为float64保证滤波运算精度
|
||||
data_np = data_np.T.astype(np.float64)
|
||||
|
||||
@@ -215,7 +295,7 @@ class zmqServer(threading.Thread):
|
||||
self._process_send_queue()
|
||||
|
||||
# 2. 轮训监听两个Socket的输入事件(10ms超时,避免阻塞)
|
||||
socks = dict(self.poller.poll(10))
|
||||
socks = dict(self.poller.poll(50))
|
||||
|
||||
# 处理命令端口消息
|
||||
if self.cmd_socket in socks and socks[self.cmd_socket] == zmq.POLLIN:
|
||||
|
||||
Reference in New Issue
Block a user