Files
bci_algo/Zmq/zmqServer.py

375 lines
16 KiB
Python
Raw Normal View History

2026-06-06 14:40:07 +08:00
import ast
2026-06-05 09:34:29 +08:00
import numpy as np
import threading
import json
import queue
2026-06-06 14:40:07 +08:00
from typing import Dict
2026-06-07 11:05:24 +08:00
import datetime
import time
2026-06-06 09:16:49 +08:00
# from Device.SunnyLinker import SunnyLinker64
2026-06-06 15:13:23 +08:00
from Zmq.dataBuffer import ParadigmRingBuffer
from Zmq.filterProcess import FilterRingBuffer
2026-06-06 14:40:07 +08:00
from PubLibrary.InifileHelper import IniRead
2026-06-06 09:16:49 +08:00
from logs.log import algo_log
2026-06-05 09:34:29 +08:00
2026-06-06 14:40:07 +08:00
import zmq
2026-06-05 09:34:29 +08:00
class zmqServer(threading.Thread):
2026-06-06 09:16:49 +08:00
def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100, device_info=None):
2026-06-05 09:34:29 +08:00
threading.Thread.__init__(self)
2026-06-06 14:40:07 +08:00
self.device_info = device_info
2026-06-05 09:34:29 +08:00
self.host = host
2026-06-06 09:16:49 +08:00
self.cmd_port = cmd_port # 命令交互端口
self.data_port = data_port # 数据接收端口
2026-06-05 09:34:29 +08:00
self.running = False
2026-06-06 09:16:49 +08:00
# 原有业务状态变量
# self.get_Impedance = False # 是否返回阻抗值
2026-06-06 17:08:09 +08:00
self.open_Impedance = False # 是否开启阻抗检测功能
2026-06-06 09:16:49 +08:00
self.StartDecode = False # false 停止解码true=开始解码
self.StartTrain = False # False未进入训练状态True处于训练状态
self.state_mode = None # 'train'为训练状态rest'为休息状态,'test'为测试状态
self.currentLabel = -1 # 接收刺激端消息,了解刺激端当前的训练标签
self.IsExitApp = False # 当socket收到2的时候就置为True代表要退出系统了。
# self.getReport = False # 获取训练报告内容
2026-06-05 09:34:29 +08:00
self.daemon = True
2026-06-06 09:16:49 +08:00
# 范式数据缓存
2026-06-06 14:40:07 +08:00
self.paradigmBuffer = ParadigmRingBuffer(self.device_info['channel_nums'], self.device_info['sample_rate'] * 10)
self.filterBuffer = FilterRingBuffer(self.device_info['channel_nums'], self.device_info['sample_rate'] * 10)
2026-06-06 17:08:09 +08:00
self.paradigmBufferLock= threading.Lock()
2026-06-06 09:16:49 +08:00
# 命令与数据通信
2026-06-05 09:34:29 +08:00
self.context = zmq.Context()
2026-06-06 09:16:49 +08:00
# 指令通道 (8099) - ROUTER短JSON命令低频率
self.cmd_socket = self.context.socket(zmq.ROUTER)
2026-06-06 15:13:23 +08:00
# 通用套接字选项:仍在 SocketOption 中
self.cmd_socket.setsockopt(zmq.SocketOption.RCVHWM, 100)
self.cmd_socket.setsockopt(zmq.SocketOption.SNDHWM, 100)
2026-06-06 09:16:49 +08:00
self.cmd_socket.bind(f"tcp://{self.host}:{cmd_port}")
# 数据通道 (8100) - ROUTER高频脑电二进制流
self.data_socket = self.context.socket(zmq.ROUTER)
2026-06-06 15:13:23 +08:00
self.data_socket.setsockopt(zmq.SocketOption.RCVHWM, 500)
2026-06-06 09:16:49 +08:00
self.data_socket.bind(f"tcp://{self.host}:{data_port}")
# Poller 轮训器(保持不变)
self.poller = zmq.Poller()
self.poller.register(self.cmd_socket, zmq.POLLIN)
self.poller.register(self.data_socket, zmq.POLLIN)
2026-06-06 15:13:23 +08:00
2026-06-06 09:16:49 +08:00
# 业务变量
2026-06-05 09:34:29 +08:00
self.targetFreqs = []
self.changeTarget = False # 更换目标频率
2026-06-06 09:16:49 +08:00
# self.sunnyLinker = SunnyLinker64(None, None, None, None,None) #单例模式类已在Decoder实例化
2026-06-05 09:34:29 +08:00
self.labels = [0x01, 0x02,0x03]
self.decoder_switch = False #更换解码器
self.decoder_class = None #解码器类别 'ssvep','ssmvep','mi'
2026-06-06 09:16:49 +08:00
# 客户端管理 - 区分命令/数据客户端
self.cmd_clients = set() # 命令端口客户端ID
self.data_clients = set() # 数据端口客户端ID
self.send_queue = queue.Queue() # 发送队列(仅用于命令端口广播)
2026-06-06 14:40:07 +08:00
# 范式buffer参数, 事件检测相关
self._event_lock = threading.Lock()
2026-06-07 11:05:24 +08:00
2026-06-06 14:40:07 +08:00
self.predict_event = 99
self.events = [1, 2, self.predict_event]
self.latency = 50
self.train_latency = 50
2026-06-07 11:05:24 +08:00
self.count_events = {}
self.epoch_finished = False
self.pack_contain_event = False
self.event_inner_idx = -1
self.interval_inited = False
2026-06-06 14:40:07 +08:00
2026-06-06 17:08:09 +08:00
def reset_state(self):
"""清空采集器状态和缓存数据"""
with self.paradigmBufferLock:
self.paradigmBuffer.resetAllPara()
self.count_events = {}
self.epoch_finished = False
self.pack_contain_event = False
self.event_inner_idx = -1
self.interval_inited = False
2026-06-06 14:40:07 +08:00
def interval_init(self, decoder_class):
if decoder_class == 'ssmvep':
interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] # epoch截取信息
self.train_epoch = [int(self.interval_epoch[0]),
int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])] # 训练样本epoch
self.latency = (self.interval_epoch[
1] + 0.1 * self.device_info['sample_rate']) // 5 # 提取epoch的延迟标记5代表每次解包得到的5位采样点;0.1表示比实际需要的长度多取0.1,会被截掉
self.train_latency = (self.train_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5
elif decoder_class == 'mi':
interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] # epoch截取信息
self.train_epoch = self.interval_epoch.copy()
self.latency = (self.interval_epoch[1]) // 5 # 提取epoch的延迟标记5代表每次解包得到的5位采样点;
self.train_latency = self.latency
print('时间窗:', (interval_epoch))
self.count_events: Dict[str, int] = {} # 表示包延迟的计数信息
self.event_inner_idx = -1 # event在5位数据包内部的idx
self.epoch_finished = False # 接收epoch是否完整
self.pack_contain_event = False # 当前包是否含有event
self.predict_event = 99
self.events = [1, 2, self.predict_event]
self.interval_inited = True
2026-06-05 09:34:29 +08:00
def broadcast_message(self, method, params):
2026-06-06 09:16:49 +08:00
"""Put message into queue to be sent to all command clients"""
2026-06-05 09:34:29 +08:00
self.send_queue.put((method, params))
2026-06-06 09:16:49 +08:00
def _handle_cmd_message(self, frames):
"""处理命令端口消息(原有命令交互逻辑)"""
if len(frames) < 3:
return
ident, _, message_bytes = frames[:3]
# 注册新的命令客户端
if ident not in self.cmd_clients:
self.cmd_clients.add(ident)
2026-06-06 14:40:07 +08:00
algo_log(f"New CMD Client Connected: {ident} (port: {self.cmd_port})")
2026-06-06 09:16:49 +08:00
# 解析消息
try:
message = json.loads(message_bytes.decode('utf-8'))
except json.JSONDecodeError:
2026-06-06 14:40:07 +08:00
algo_log(f"Invalid JSON from CMD client {ident}")
return
algo_log(f"Received CMD request: {message}")
2026-06-06 09:16:49 +08:00
method = message.get("method")
params = message.get("params")
# 原有命令处理逻辑
if method == "sync":
self.state_mode = 'sync'
2026-06-06 14:40:07 +08:00
elif method == "targetFreqs":
2026-06-06 09:16:49 +08:00
if not isinstance(params, list):
2026-06-06 14:40:07 +08:00
algo_log(f"targetFreqs must be a list")
return
2026-06-06 09:16:49 +08:00
if params != self.targetFreqs:
self.targetFreqs = params
self.changeTarget = True
2026-06-06 14:40:07 +08:00
elif method == "decoderClass":
2026-06-06 09:16:49 +08:00
if not isinstance(params, str):
2026-06-06 14:40:07 +08:00
algo_log(f"decoderClass must be a str")
return
2026-06-06 09:16:49 +08:00
if params != self.decoder_class:
self.decoder_class = params
self.decoder_switch = True
2026-06-06 14:40:07 +08:00
elif method == "train":#训练状态
2026-06-06 09:16:49 +08:00
self.state_mode = 'train'
self.StartTrain = True
self.currentLabel = params # 当前刺激端的训练标签
2026-06-06 14:40:07 +08:00
# self.sunnyLinker.push_trigger(self.labels[self.currentLabel])
2026-06-06 09:16:49 +08:00
elif method == "predict":#预测状态
self.state_mode = 'predict'
if params == 1: #开始解码
self.StartDecode = True
2026-06-06 14:40:07 +08:00
# self.sunnyLinker.push_trigger(0x63)
2026-06-06 09:16:49 +08:00
elif params == 2: #停止解码
self.IsExitApp = True
self.running = False
elif method == "rest": #休息状态
self.state_mode = 'rest'
2026-06-06 17:08:09 +08:00
elif method == "impedance":
if params == 1:
self.open_Impedance = True # 开启阻抗
# self.get_Impedance = True # 返回阻抗
elif params == 2:
self.open_Impedance = False # 关闭阻抗
2026-06-06 14:40:07 +08:00
else:
algo_log(f"未知命令:{method}", level="WARNING")
2026-06-06 17:08:09 +08:00
2026-06-06 14:40:07 +08:00
# elif method == "getReport":
# self.getReport = True
2026-06-06 17:08:09 +08:00
2026-06-06 09:16:49 +08:00
# elif params == 2:
# self.open_Impedance = False # 关闭阻抗
# self.get_Impedance = False # 停止返回阻抗
def _handle_data_message(self, frames):
"""
处理8100端口原始脑电二进制数据
固定格式上位机发送 (5,66) float32 二维数组字节流已转换为微伏物理量 转置为 (66,5) 写入双缓冲区
"""
2026-06-06 15:53:50 +08:00
# 1. 校验ZMQ消息帧完整性ROUTER接收DEALER消息的帧格式[客户端ID, 发送方ID, 空帧, 数据帧]
if len(frames) < 4: # 至少需要4帧
algo_log(f"Invalid data frame: 帧数量不足期望≥4实际{len(frames)}", level="ERROR")
2026-06-06 09:16:49 +08:00
return
2026-06-06 15:53:50 +08:00
# 2. 正确解析帧适配DEALER→ROUTER的帧格式
client_ident, sender_ident, empty_sep, data_bytes = frames[:4]
if empty_sep != b'': # 校验空分隔帧
algo_log(f"Invalid frame separator: 期望空字节,实际{empty_sep}", level="ERROR")
return
2026-06-06 09:16:49 +08:00
2026-06-06 15:53:50 +08:00
# 3. 客户端管理(单客户端场景,自动更新最新身份)
if client_ident not in self.data_clients:
self.data_clients.add(client_ident)
self.current_data_client = client_ident # 保存唯一客户端身份,用于后续回复滤波结果
print(f"[INFO] 新数据客户端连接成功:{client_ident}")
2026-06-06 09:16:49 +08:00
try:
2026-06-06 15:53:50 +08:00
# 4. 精确长度校验(核心:固定(5,66) float32 = 5*66*4=1320字节
2026-06-06 14:40:07 +08:00
EXPECTED_BYTES = self.device_info['frame_points'] * self.device_info['channel_nums'] * 4 # 每个float32占4字节
2026-06-06 09:16:49 +08:00
if len(data_bytes) != EXPECTED_BYTES:
2026-06-06 15:53:50 +08:00
algo_log(f"[ERROR] 数据长度错误:期望{EXPECTED_BYTES}字节,实际{len(data_bytes)}字节", level="ERROR")
2026-06-06 09:16:49 +08:00
return
2026-06-06 15:53:50 +08:00
# 5. 零拷贝二进制解析 + 维度转换
2026-06-06 09:16:49 +08:00
data_np = np.frombuffer(data_bytes, dtype=np.float32)
2026-06-06 14:40:07 +08:00
data_np = data_np.reshape(self.device_info['frame_points'], self.device_info['channel_nums'])
2026-06-06 09:16:49 +08:00
data_np = data_np.T.astype(np.float64)
2026-06-07 11:05:24 +08:00
# 6. 写入滤波缓冲区
2026-06-06 09:16:49 +08:00
self.filterBuffer.appendBuffer(data_np)
2026-06-07 11:05:24 +08:00
# 7. 写入范式缓冲区
try:
with self.paradigmBufferLock:
if self.interval_inited:
self.epoch_finished = self.detect_event(data_np)
if self.pack_contain_event:
self.paradigmBuffer.resetAllPara() # 检测到当前pack含有event清除ringbuffer中之前的数据
self.paradigmBuffer.appendBuffer(data_np)
if self.epoch_finished:
time.sleep(0.005)
algo_log('epoch_finished: ' + datetime.datetime.now().strftime('%H:%M:%S.%f')[:-3], level="DEBUG")
else:
self.paradigmBuffer.appendBuffer(data_np)
except Exception as e:
print("锁:写入异常",e)
self.paradigmBuffer.appendBuffer(data_np)
2026-06-06 09:16:49 +08:00
2026-06-06 17:08:09 +08:00
# algo_log(f"数据写入成功shape={data_np.shape}, 范围=[{data_np.min():.2f}, {data_np.max():.2f}] μV", level="DEBUG")
2026-06-06 09:16:49 +08:00
except Exception as e:
algo_log(f"数据处理失败:{str(e)}", level="ERROR")
2026-06-06 15:53:50 +08:00
if IniRead('system', 'algo_log_level', 'INFO') == 'DEBUG':
import traceback
traceback.print_exc()
2026-06-07 11:05:24 +08:00
# 检测是否含有标签
def detect_event(self, samples):
self.pack_contain_event = False
events = np.array(samples[-2])[0].tolist()
for idx, event in enumerate(events):
if int(event) in self.events:
new_key = "".join(
[
str(event),
datetime.datetime.now().strftime("%Y-%m-%d \
-%H-%M-%S"),
]
)
if event == self.predict_event:
self.count_events[new_key] = self.latency + 1
else:
self.count_events[new_key] = self.train_latency + 1
self.event_inner_idx = idx
self.pack_contain_event = True
drop_items = []
for key, value in self.count_events.items():
value = value - 1
if value == 0:
drop_items.append(key)
self.count_events[key] = value
for key in drop_items:
del self.count_events[key]
if drop_items:
return True
return False
2026-06-06 09:16:49 +08:00
def _process_send_queue(self):
"""处理发送队列,向所有命令客户端广播消息"""
while not self.send_queue.empty():
method, params = self.send_queue.get()
if self.cmd_clients:
try:
msg = {'method': method, 'params': params}
msg_bytes = json.dumps(msg).encode('utf-8')
# 打印日志(隐藏大尺寸数据)
if method in ['single_trial_plot', 'miReport']:
print(f"{{'method': '{method}', 'params': <Base64 Image Data>}}")
else:
print(f"Sending CMD message: {msg}")
# 广播到所有命令客户端
for client_id in list(self.cmd_clients):
try:
self.cmd_socket.send_multipart([client_id, b'', msg_bytes])
except Exception as e:
print(f"Error sending to CMD client {client_id}: {e}")
self.cmd_clients.discard(client_id) # 移除失效客户端
except Exception as e:
print(f"Error preparing broadcast: {e}")
2026-06-05 09:34:29 +08:00
def run(self):
self.running = True
2026-06-06 15:13:23 +08:00
algo_log(f"algo ZMQ Server started - CMD Port: {self.cmd_port}, DATA Port: {self.data_port}", level="INFO")
2026-06-06 09:16:49 +08:00
2026-06-05 09:34:29 +08:00
try:
while self.running:
2026-06-06 09:16:49 +08:00
# 1. 处理发送队列(命令端口广播)
self._process_send_queue()
2026-06-06 15:13:23 +08:00
# 2. 轮训监听两个Socket的输入事件
2026-06-06 14:40:07 +08:00
socks = dict(self.poller.poll(50))
2026-06-06 09:16:49 +08:00
# 处理命令端口消息
if self.cmd_socket in socks and socks[self.cmd_socket] == zmq.POLLIN:
frames = self.cmd_socket.recv_multipart()
self._handle_cmd_message(frames)
# 处理数据端口消息
if self.data_socket in socks and socks[self.data_socket] == zmq.POLLIN:
frames = self.data_socket.recv_multipart()
self._handle_data_message(frames)
2026-06-05 09:34:29 +08:00
except Exception as e:
2026-06-06 09:16:49 +08:00
print(f"Server error occurred: {e}")
2026-06-05 09:34:29 +08:00
finally:
self.running = False
2026-06-06 09:16:49 +08:00
# 关闭所有Socket和上下文
self.cmd_socket.close()
self.data_socket.close()
2026-06-05 09:34:29 +08:00
self.context.term()
2026-06-06 09:16:49 +08:00
print("Server sockets and context closed.")
2026-06-05 09:34:29 +08:00
def stop(self):
"""显式关闭服务器"""
self.running = False
2026-06-06 09:16:49 +08:00
self.cmd_socket.close()
self.data_socket.close()
2026-06-05 09:34:29 +08:00
self.context.term()
2026-06-06 09:16:49 +08:00
print(f"Server closed explicitly - CMD Port: {self.cmd_port}, DATA Port: {self.data_port}")
2026-06-05 09:34:29 +08:00
if __name__ == '__main__':
2026-06-06 09:16:49 +08:00
# 初始化并启动服务器默认cmd=8099, data=8100
2026-06-05 09:34:29 +08:00
server = zmqServer()
2026-06-06 09:16:49 +08:00
server.start()
# 保持主线程运行
try:
while server.running:
threading.Event().wait(1)
except KeyboardInterrupt:
print("Received KeyboardInterrupt, stopping server...")
server.stop()