dev 1
This commit is contained in:
328
Zmq/zmqServer.py
328
Zmq/zmqServer.py
@@ -3,147 +3,257 @@ import zmq
|
||||
import threading
|
||||
import json
|
||||
import queue
|
||||
from Device.SunnyLinker import SunnyLinker64
|
||||
# from Device.SunnyLinker import SunnyLinker64
|
||||
from dataBuffer import ParadigmRingBuffer
|
||||
from filterProcess import FilterRingBuffer
|
||||
from logs.log import algo_log
|
||||
|
||||
class zmqServer(threading.Thread):
|
||||
def __init__(self, host='0.0.0.0', port=8099):
|
||||
def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100, device_info=None):
|
||||
threading.Thread.__init__(self)
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.cmd_port = cmd_port # 命令交互端口
|
||||
self.data_port = data_port # 数据接收端口
|
||||
self.running = False
|
||||
self.get_Impedance = False # 是否返回阻抗值
|
||||
self.open_Impedance = None # 是否开启阻抗检测功能
|
||||
self.StartDecode = False # false 停止解码,true=开始解码
|
||||
self.StartTrain = False # False未进入训练状态,True处于训练状态
|
||||
self.state_mode = None # 'train'为训练状态,’rest'为休息状态,'test'为测试状态
|
||||
self.currentLabel = -1 # 接收刺激端消息,了解刺激端当前的训练标签
|
||||
self.IsExitApp = False # 当socket收到2的时候,就置为True,代表要退出系统了。
|
||||
self.getReport = False # 获取训练报告内容
|
||||
|
||||
# 原有业务状态变量
|
||||
# self.get_Impedance = False # 是否返回阻抗值
|
||||
# self.open_Impedance = None # 是否开启阻抗检测功能
|
||||
self.StartDecode = False # false 停止解码,true=开始解码
|
||||
self.StartTrain = False # False未进入训练状态,True处于训练状态
|
||||
self.state_mode = None # 'train'为训练状态,’rest'为休息状态,'test'为测试状态
|
||||
self.currentLabel = -1 # 接收刺激端消息,了解刺激端当前的训练标签
|
||||
self.IsExitApp = False # 当socket收到2的时候,就置为True,代表要退出系统了。
|
||||
# self.getReport = False # 获取训练报告内容
|
||||
self.daemon = True
|
||||
# 创建 ZeroMQ 上下文
|
||||
|
||||
# 范式数据缓存
|
||||
self.paradigmBuffer = ParadigmRingBuffer(66, 2500)
|
||||
self.filterBuffer = FilterRingBuffer(66, 2500)
|
||||
|
||||
|
||||
# 命令与数据通信
|
||||
self.context = zmq.Context()
|
||||
# 创建 REP 套接字(响应端)
|
||||
self.socket = self.context.socket(zmq.ROUTER)
|
||||
self.socket.bind(f"tcp://{self.host}:{self.port}") # 绑定到端口 8099
|
||||
# 指令通道 (8099) - ROUTER:短JSON命令,低频率
|
||||
self.cmd_socket = self.context.socket(zmq.ROUTER)
|
||||
self.cmd_socket.setsockopt(zmq.RCVHWM, 100) # 指令不需要大缓存,100条足够
|
||||
self.cmd_socket.setsockopt(zmq.SNDHWM, 100)
|
||||
self.cmd_socket.setsockopt(zmq.TCP_NODELAY, 1) # 禁用Nagle算法,降低指令延迟
|
||||
self.cmd_socket.bind(f"tcp://{self.host}:{cmd_port}")
|
||||
|
||||
# 数据通道 (8100) - ROUTER:高频脑电二进制流
|
||||
self.data_socket = self.context.socket(zmq.ROUTER)
|
||||
self.data_socket.setsockopt(zmq.RCVHWM, 500) # 500包=10秒缓存,足够应对短时卡顿
|
||||
self.data_socket.setsockopt(zmq.TCP_NODELAY, 1) # 禁用Nagle算法,减少数据传输延迟
|
||||
self.data_socket.bind(f"tcp://{self.host}:{data_port}")
|
||||
|
||||
# Poller 轮训器(保持不变)
|
||||
self.poller = zmq.Poller()
|
||||
self.poller.register(self.cmd_socket, zmq.POLLIN)
|
||||
self.poller.register(self.data_socket, zmq.POLLIN)
|
||||
|
||||
# 业务变量
|
||||
self.targetFreqs = []
|
||||
self.changeTarget = False # 更换目标频率
|
||||
self.sunnyLinker = SunnyLinker64(None, None, None, None,None) #单例模式类,已在Decoder实例化
|
||||
# self.sunnyLinker = SunnyLinker64(None, None, None, None,None) #单例模式类,已在Decoder实例化
|
||||
self.labels = [0x01, 0x02,0x03]
|
||||
|
||||
self.decoder_switch = False #更换解码器
|
||||
self.decoder_class = None #解码器类别 'ssvep','ssmvep','mi'
|
||||
# Client Management (e.g. Unity, Other listeners)
|
||||
self.clients = set() # 维护客户端ID
|
||||
self.send_queue = queue.Queue() # 发送队列,安全信箱,维护socket线程
|
||||
|
||||
# 客户端管理 - 区分命令/数据客户端
|
||||
self.cmd_clients = set() # 命令端口客户端ID
|
||||
self.data_clients = set() # 数据端口客户端ID
|
||||
self.send_queue = queue.Queue() # 发送队列(仅用于命令端口广播)
|
||||
|
||||
def broadcast_message(self, method, params):
|
||||
"""Put message into queue to be sent to all connected clients"""
|
||||
"""Put message into queue to be sent to all command clients"""
|
||||
self.send_queue.put((method, params))
|
||||
|
||||
def _handle_cmd_message(self, frames):
|
||||
"""处理命令端口消息(原有命令交互逻辑)"""
|
||||
if len(frames) < 3:
|
||||
return
|
||||
ident, _, message_bytes = frames[:3]
|
||||
|
||||
# 注册新的命令客户端
|
||||
if ident not in self.cmd_clients:
|
||||
self.cmd_clients.add(ident)
|
||||
print(f"New CMD Client Connected: {ident} (port: {self.cmd_port})")
|
||||
|
||||
# 解析消息
|
||||
try:
|
||||
message = json.loads(message_bytes.decode('utf-8'))
|
||||
except json.JSONDecodeError:
|
||||
print(f"Invalid JSON from CMD client {ident}")
|
||||
continue
|
||||
print(f"Received CMD request: {message}")
|
||||
|
||||
method = message.get("method")
|
||||
params = message.get("params")
|
||||
|
||||
# 原有命令处理逻辑
|
||||
if method == "sync":
|
||||
self.state_mode = 'sync'
|
||||
if method == "targetFreqs":
|
||||
if not isinstance(params, list):
|
||||
print('targetFreqs must be a list')
|
||||
continue
|
||||
if params != self.targetFreqs:
|
||||
self.targetFreqs = params
|
||||
self.changeTarget = True
|
||||
if method == "decoderClass":
|
||||
if not isinstance(params, str):
|
||||
print('decoderClass must be a str')
|
||||
continue
|
||||
if params != self.decoder_class:
|
||||
self.decoder_class = params
|
||||
self.decoder_switch = True
|
||||
if method == "getReport":
|
||||
self.getReport = True
|
||||
if method == "train":#训练状态
|
||||
self.state_mode = 'train'
|
||||
self.StartTrain = True
|
||||
self.currentLabel = params # 当前刺激端的训练标签
|
||||
self.sunnyLinker.push_trigger(self.labels[self.currentLabel])
|
||||
elif method == "predict":#预测状态
|
||||
self.state_mode = 'predict'
|
||||
if params == 1: #开始解码
|
||||
self.StartDecode = True
|
||||
self.sunnyLinker.push_trigger(0x63)
|
||||
elif params == 2: #停止解码
|
||||
self.IsExitApp = True
|
||||
self.running = False
|
||||
elif method == "rest": #休息状态
|
||||
self.state_mode = 'rest'
|
||||
# elif method == "impedance":
|
||||
# if params == 1:
|
||||
# self.open_Impedance = True # 开启阻抗
|
||||
# self.get_Impedance = True # 返回阻抗
|
||||
# elif params == 2:
|
||||
# self.open_Impedance = False # 关闭阻抗
|
||||
# self.get_Impedance = False # 停止返回阻抗
|
||||
|
||||
def _handle_data_message(self, frames):
|
||||
"""
|
||||
处理8100端口原始脑电二进制数据
|
||||
固定格式:上位机发送 (5,66) float32 二维数组字节流(已转换为微伏物理量)→ 转置为 (66,5) 写入双缓冲区
|
||||
"""
|
||||
# 1. 校验ZMQ消息帧完整性
|
||||
if len(frames) < 3:
|
||||
print(f"[ERROR] 无效数据帧:长度不足3帧,实际长度={len(frames)}")
|
||||
return
|
||||
|
||||
ident, _, data_bytes = frames[:3]
|
||||
|
||||
# 2. 客户端管理(单客户端场景,自动更新最新身份)
|
||||
if ident not in self.data_clients:
|
||||
self.data_clients.add(ident)
|
||||
self.current_data_client = ident # 保存唯一客户端身份,用于后续回复滤波结果
|
||||
print(f"[INFO] 新数据客户端连接成功:{ident}")
|
||||
|
||||
try:
|
||||
# 3. 精确长度校验(核心:固定(5,66) float32 = 5*66*4=1320字节,与int32字节数相同)
|
||||
EXPECTED_BYTES = 5 * 66 * 4 # 每个float32占4字节
|
||||
if len(data_bytes) != EXPECTED_BYTES:
|
||||
print(f"[ERROR] 数据长度错误:期望{EXPECTED_BYTES}字节,实际{len(data_bytes)}字节")
|
||||
return
|
||||
|
||||
# 4. 零拷贝二进制解析 + 维度转换
|
||||
# 步骤:字节流 → (330,) float32数组 → (5,66) 原始格式 → 转置为 (66,5) 缓冲区标准格式
|
||||
data_np = np.frombuffer(data_bytes, dtype=np.float32)
|
||||
# 重塑为上位机原始维度
|
||||
data_np = data_np.reshape(5, 66)
|
||||
# 转置为(通道数, 采样点数)标准格式,转换为float64保证滤波运算精度
|
||||
data_np = data_np.T.astype(np.float64)
|
||||
|
||||
# 5. 同时写入双环形缓冲区(方法名与现有类保持一致:appendBuffer)
|
||||
# 注意:上位机已发送微伏物理量,无需再乘以增益系数
|
||||
self.paradigmBuffer.appendBuffer(data_np)
|
||||
self.filterBuffer.appendBuffer(data_np)
|
||||
|
||||
# 生产环境必须注释!每秒50次打印会导致CPU占用飙升30%以上
|
||||
algo_log(f"数据写入成功:shape={data_np.shape}, 范围=[{data_np.min():.2f}, {data_np.max():.2f}] μV", level="DEBUG", record_once=True)
|
||||
|
||||
except Exception as e:
|
||||
algo_log(f"数据处理失败:{str(e)}", level="ERROR")
|
||||
# 调试阶段临时打开,生产环境务必注释
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
def _process_send_queue(self):
|
||||
"""处理发送队列,向所有命令客户端广播消息"""
|
||||
while not self.send_queue.empty():
|
||||
method, params = self.send_queue.get()
|
||||
if self.cmd_clients:
|
||||
try:
|
||||
msg = {'method': method, 'params': params}
|
||||
msg_bytes = json.dumps(msg).encode('utf-8')
|
||||
|
||||
# 打印日志(隐藏大尺寸数据)
|
||||
if method in ['single_trial_plot', 'miReport']:
|
||||
print(f"{{'method': '{method}', 'params': <Base64 Image Data>}}")
|
||||
else:
|
||||
print(f"Sending CMD message: {msg}")
|
||||
|
||||
# 广播到所有命令客户端
|
||||
for client_id in list(self.cmd_clients):
|
||||
try:
|
||||
self.cmd_socket.send_multipart([client_id, b'', msg_bytes])
|
||||
except Exception as e:
|
||||
print(f"Error sending to CMD client {client_id}: {e}")
|
||||
self.cmd_clients.discard(client_id) # 移除失效客户端
|
||||
except Exception as e:
|
||||
print(f"Error preparing broadcast: {e}")
|
||||
|
||||
def run(self):
|
||||
self.running = True
|
||||
print(f"Server is running on {self.host}:{self.port}")
|
||||
# Use Poller for non-blocking receive
|
||||
poller = zmq.Poller()
|
||||
poller.register(self.socket, zmq.POLLIN)
|
||||
print(f"ZMQ Server started - CMD Port: {self.cmd_port}, DATA Port: {self.data_port}")
|
||||
|
||||
try:
|
||||
while self.running:
|
||||
# 1. Process Send Queue (Send to all clients)
|
||||
while not self.send_queue.empty():
|
||||
method, params = self.send_queue.get()
|
||||
if self.clients:
|
||||
try:
|
||||
msg = {'method': method, 'params': params}
|
||||
msg_bytes = json.dumps(msg).encode('utf-8')
|
||||
if method in ['single_trial_plot', 'single_trial_plot', 'miReport']:
|
||||
print(f"{{'method': '{method}', 'params': <Base64 Image Data>}}")
|
||||
else:
|
||||
print(f"Sending message: {msg}")
|
||||
# Broadcast to all maintained clients
|
||||
for client_id in list(self.clients):
|
||||
try:
|
||||
# Send: [ID, Empty, JSON]
|
||||
self.socket.send_multipart([client_id, b'', msg_bytes])
|
||||
except Exception as e:
|
||||
print(f"Error sending to client {client_id}: {e}")
|
||||
except Exception as e:
|
||||
print(f"Error preparing broadcast: {e}")
|
||||
# 1. 处理发送队列(命令端口广播)
|
||||
self._process_send_queue()
|
||||
|
||||
# 2. Process Receive (Commands)
|
||||
socks = dict(poller.poll(10)) # 100ms timeout
|
||||
if self.socket in socks and socks[self.socket] == zmq.POLLIN:
|
||||
frames = self.socket.recv_multipart()
|
||||
if len(frames) < 3:
|
||||
continue
|
||||
ident, _, message_bytes = frames[:3]
|
||||
if ident not in self.clients: # register client ID
|
||||
self.clients.add(ident)
|
||||
print(f"New Client Detected: {ident}")
|
||||
try:
|
||||
message = json.loads(message_bytes.decode('utf-8'))
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
print(f"Received request: {message}")
|
||||
# 2. 轮训监听两个Socket的输入事件(10ms超时,避免阻塞)
|
||||
socks = dict(self.poller.poll(10))
|
||||
|
||||
method = message.get("method") # process request
|
||||
params = message.get("params")
|
||||
# 处理命令端口消息
|
||||
if self.cmd_socket in socks and socks[self.cmd_socket] == zmq.POLLIN:
|
||||
frames = self.cmd_socket.recv_multipart()
|
||||
self._handle_cmd_message(frames)
|
||||
|
||||
if method == "sync":
|
||||
self.state_mode = 'sync'
|
||||
if method == "targetFreqs":
|
||||
if not isinstance(params,list):
|
||||
print('targetFreqs must be a list')
|
||||
continue
|
||||
if params != self.targetFreqs:
|
||||
self.targetFreqs = params
|
||||
self.changeTarget = True
|
||||
if method == "decoderClass":
|
||||
if not isinstance(params,str):
|
||||
print('decoderClass must be a str')
|
||||
continue
|
||||
if params != self.decoder_class:
|
||||
self.decoder_class = params
|
||||
self.decoder_switch = True
|
||||
if method == "getReport":
|
||||
self.getReport = True
|
||||
if method == "train":#训练状态
|
||||
self.state_mode = 'train'
|
||||
self.StartTrain = True
|
||||
self.currentLabel = params # 当前刺激端的训练标签
|
||||
self.sunnyLinker.push_trigger(self.labels[self.currentLabel])
|
||||
elif method == "predict":#预测状态
|
||||
self.state_mode = 'predict'
|
||||
if params == 1: #开始解码
|
||||
self.StartDecode = True
|
||||
self.sunnyLinker.push_trigger(0x63)
|
||||
elif params == 2: #停止解码
|
||||
self.IsExitApp = True
|
||||
self.running = False
|
||||
elif method == "rest": #休息状态
|
||||
self.state_mode = 'rest'
|
||||
elif method == "impedance":
|
||||
if params == 1:
|
||||
self.open_Impedance = True # 开启阻抗
|
||||
self.get_Impedance = True # 返回阻抗
|
||||
elif params == 2:
|
||||
self.open_Impedance = False # 关闭阻抗
|
||||
self.get_Impedance = False # 停止返回阻抗
|
||||
# 处理数据端口消息
|
||||
if self.data_socket in socks and socks[self.data_socket] == zmq.POLLIN:
|
||||
frames = self.data_socket.recv_multipart()
|
||||
self._handle_data_message(frames)
|
||||
|
||||
except Exception as e:
|
||||
print(f"An socket error occurred: {e}")
|
||||
print(f"Server error occurred: {e}")
|
||||
finally:
|
||||
self.running = False
|
||||
# 关闭套接字和上下文
|
||||
self.socket.close()
|
||||
# 关闭所有Socket和上下文
|
||||
self.cmd_socket.close()
|
||||
self.data_socket.close()
|
||||
self.context.term()
|
||||
print("Server socket and context closed.")
|
||||
print("Server sockets and context closed.")
|
||||
|
||||
def stop(self):
|
||||
"""显式关闭服务器"""
|
||||
self.running = False
|
||||
self.socket.close()
|
||||
self.cmd_socket.close()
|
||||
self.data_socket.close()
|
||||
self.context.term()
|
||||
print("Server closed explicitly.")
|
||||
print(f"Server closed explicitly - CMD Port: {self.cmd_port}, DATA Port: {self.data_port}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 初始化并启动服务器(默认cmd=8099, data=8100)
|
||||
server = zmqServer()
|
||||
server.start()
|
||||
server.start()
|
||||
|
||||
# 保持主线程运行
|
||||
try:
|
||||
while server.running:
|
||||
threading.Event().wait(1)
|
||||
except KeyboardInterrupt:
|
||||
print("Received KeyboardInterrupt, stopping server...")
|
||||
server.stop()
|
||||
|
||||
Reference in New Issue
Block a user