Files
bci_algo/Zmq/zmqServer.py
2026-06-06 09:16:49 +08:00

260 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import numpy as np
import zmq
import threading
import json
import queue
# from Device.SunnyLinker import SunnyLinker64
from dataBuffer import ParadigmRingBuffer
from filterProcess import FilterRingBuffer
from logs.log import algo_log
class zmqServer(threading.Thread):
def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100, device_info=None):
threading.Thread.__init__(self)
self.host = host
self.cmd_port = cmd_port # 命令交互端口
self.data_port = data_port # 数据接收端口
self.running = False
# 原有业务状态变量
# self.get_Impedance = False # 是否返回阻抗值
# self.open_Impedance = None # 是否开启阻抗检测功能
self.StartDecode = False # false 停止解码true=开始解码
self.StartTrain = False # False未进入训练状态True处于训练状态
self.state_mode = None # 'train'为训练状态rest'为休息状态,'test'为测试状态
self.currentLabel = -1 # 接收刺激端消息,了解刺激端当前的训练标签
self.IsExitApp = False # 当socket收到2的时候就置为True代表要退出系统了。
# self.getReport = False # 获取训练报告内容
self.daemon = True
# 范式数据缓存
self.paradigmBuffer = ParadigmRingBuffer(66, 2500)
self.filterBuffer = FilterRingBuffer(66, 2500)
# 命令与数据通信
self.context = zmq.Context()
# 指令通道 (8099) - ROUTER短JSON命令低频率
self.cmd_socket = self.context.socket(zmq.ROUTER)
self.cmd_socket.setsockopt(zmq.RCVHWM, 100) # 指令不需要大缓存100条足够
self.cmd_socket.setsockopt(zmq.SNDHWM, 100)
self.cmd_socket.setsockopt(zmq.TCP_NODELAY, 1) # 禁用Nagle算法降低指令延迟
self.cmd_socket.bind(f"tcp://{self.host}:{cmd_port}")
# 数据通道 (8100) - ROUTER高频脑电二进制流
self.data_socket = self.context.socket(zmq.ROUTER)
self.data_socket.setsockopt(zmq.RCVHWM, 500) # 500包=10秒缓存足够应对短时卡顿
self.data_socket.setsockopt(zmq.TCP_NODELAY, 1) # 禁用Nagle算法减少数据传输延迟
self.data_socket.bind(f"tcp://{self.host}:{data_port}")
# Poller 轮训器(保持不变)
self.poller = zmq.Poller()
self.poller.register(self.cmd_socket, zmq.POLLIN)
self.poller.register(self.data_socket, zmq.POLLIN)
# 业务变量
self.targetFreqs = []
self.changeTarget = False # 更换目标频率
# self.sunnyLinker = SunnyLinker64(None, None, None, None,None) #单例模式类已在Decoder实例化
self.labels = [0x01, 0x02,0x03]
self.decoder_switch = False #更换解码器
self.decoder_class = None #解码器类别 'ssvep','ssmvep','mi'
# 客户端管理 - 区分命令/数据客户端
self.cmd_clients = set() # 命令端口客户端ID
self.data_clients = set() # 数据端口客户端ID
self.send_queue = queue.Queue() # 发送队列(仅用于命令端口广播)
def broadcast_message(self, method, params):
"""Put message into queue to be sent to all command clients"""
self.send_queue.put((method, params))
def _handle_cmd_message(self, frames):
"""处理命令端口消息(原有命令交互逻辑)"""
if len(frames) < 3:
return
ident, _, message_bytes = frames[:3]
# 注册新的命令客户端
if ident not in self.cmd_clients:
self.cmd_clients.add(ident)
print(f"New CMD Client Connected: {ident} (port: {self.cmd_port})")
# 解析消息
try:
message = json.loads(message_bytes.decode('utf-8'))
except json.JSONDecodeError:
print(f"Invalid JSON from CMD client {ident}")
continue
print(f"Received CMD request: {message}")
method = message.get("method")
params = message.get("params")
# 原有命令处理逻辑
if method == "sync":
self.state_mode = 'sync'
if method == "targetFreqs":
if not isinstance(params, list):
print('targetFreqs must be a list')
continue
if params != self.targetFreqs:
self.targetFreqs = params
self.changeTarget = True
if method == "decoderClass":
if not isinstance(params, str):
print('decoderClass must be a str')
continue
if params != self.decoder_class:
self.decoder_class = params
self.decoder_switch = True
if method == "getReport":
self.getReport = True
if method == "train":#训练状态
self.state_mode = 'train'
self.StartTrain = True
self.currentLabel = params # 当前刺激端的训练标签
self.sunnyLinker.push_trigger(self.labels[self.currentLabel])
elif method == "predict":#预测状态
self.state_mode = 'predict'
if params == 1: #开始解码
self.StartDecode = True
self.sunnyLinker.push_trigger(0x63)
elif params == 2: #停止解码
self.IsExitApp = True
self.running = False
elif method == "rest": #休息状态
self.state_mode = 'rest'
# elif method == "impedance":
# if params == 1:
# self.open_Impedance = True # 开启阻抗
# self.get_Impedance = True # 返回阻抗
# elif params == 2:
# self.open_Impedance = False # 关闭阻抗
# self.get_Impedance = False # 停止返回阻抗
def _handle_data_message(self, frames):
"""
处理8100端口原始脑电二进制数据
固定格式:上位机发送 (5,66) float32 二维数组字节流(已转换为微伏物理量)→ 转置为 (66,5) 写入双缓冲区
"""
# 1. 校验ZMQ消息帧完整性
if len(frames) < 3:
print(f"[ERROR] 无效数据帧长度不足3帧实际长度={len(frames)}")
return
ident, _, data_bytes = frames[:3]
# 2. 客户端管理(单客户端场景,自动更新最新身份)
if ident not in self.data_clients:
self.data_clients.add(ident)
self.current_data_client = ident # 保存唯一客户端身份,用于后续回复滤波结果
print(f"[INFO] 新数据客户端连接成功:{ident}")
try:
# 3. 精确长度校验(核心:固定(5,66) float32 = 5*66*4=1320字节与int32字节数相同
EXPECTED_BYTES = 5 * 66 * 4 # 每个float32占4字节
if len(data_bytes) != EXPECTED_BYTES:
print(f"[ERROR] 数据长度错误:期望{EXPECTED_BYTES}字节,实际{len(data_bytes)}字节")
return
# 4. 零拷贝二进制解析 + 维度转换
# 步骤:字节流 → (330,) float32数组 → (5,66) 原始格式 → 转置为 (66,5) 缓冲区标准格式
data_np = np.frombuffer(data_bytes, dtype=np.float32)
# 重塑为上位机原始维度
data_np = data_np.reshape(5, 66)
# 转置为(通道数, 采样点数)标准格式转换为float64保证滤波运算精度
data_np = data_np.T.astype(np.float64)
# 5. 同时写入双环形缓冲区方法名与现有类保持一致appendBuffer
# 注意:上位机已发送微伏物理量,无需再乘以增益系数
self.paradigmBuffer.appendBuffer(data_np)
self.filterBuffer.appendBuffer(data_np)
# 生产环境必须注释每秒50次打印会导致CPU占用飙升30%以上
algo_log(f"数据写入成功shape={data_np.shape}, 范围=[{data_np.min():.2f}, {data_np.max():.2f}] μV", level="DEBUG", record_once=True)
except Exception as e:
algo_log(f"数据处理失败:{str(e)}", level="ERROR")
# 调试阶段临时打开,生产环境务必注释
import traceback
traceback.print_exc()
def _process_send_queue(self):
"""处理发送队列,向所有命令客户端广播消息"""
while not self.send_queue.empty():
method, params = self.send_queue.get()
if self.cmd_clients:
try:
msg = {'method': method, 'params': params}
msg_bytes = json.dumps(msg).encode('utf-8')
# 打印日志(隐藏大尺寸数据)
if method in ['single_trial_plot', 'miReport']:
print(f"{{'method': '{method}', 'params': <Base64 Image Data>}}")
else:
print(f"Sending CMD message: {msg}")
# 广播到所有命令客户端
for client_id in list(self.cmd_clients):
try:
self.cmd_socket.send_multipart([client_id, b'', msg_bytes])
except Exception as e:
print(f"Error sending to CMD client {client_id}: {e}")
self.cmd_clients.discard(client_id) # 移除失效客户端
except Exception as e:
print(f"Error preparing broadcast: {e}")
def run(self):
self.running = True
print(f"ZMQ Server started - CMD Port: {self.cmd_port}, DATA Port: {self.data_port}")
try:
while self.running:
# 1. 处理发送队列(命令端口广播)
self._process_send_queue()
# 2. 轮训监听两个Socket的输入事件10ms超时避免阻塞
socks = dict(self.poller.poll(10))
# 处理命令端口消息
if self.cmd_socket in socks and socks[self.cmd_socket] == zmq.POLLIN:
frames = self.cmd_socket.recv_multipart()
self._handle_cmd_message(frames)
# 处理数据端口消息
if self.data_socket in socks and socks[self.data_socket] == zmq.POLLIN:
frames = self.data_socket.recv_multipart()
self._handle_data_message(frames)
except Exception as e:
print(f"Server error occurred: {e}")
finally:
self.running = False
# 关闭所有Socket和上下文
self.cmd_socket.close()
self.data_socket.close()
self.context.term()
print("Server sockets and context closed.")
def stop(self):
"""显式关闭服务器"""
self.running = False
self.cmd_socket.close()
self.data_socket.close()
self.context.term()
print(f"Server closed explicitly - CMD Port: {self.cmd_port}, DATA Port: {self.data_port}")
if __name__ == '__main__':
# 初始化并启动服务器默认cmd=8099, data=8100
server = zmqServer()
server.start()
# 保持主线程运行
try:
while server.running:
threading.Event().wait(1)
except KeyboardInterrupt:
print("Received KeyboardInterrupt, stopping server...")
server.stop()