"""ZMQ server thread that receives commands and EEG data frames from a host PC."""

import json
import queue
import threading
import time
from collections import deque

import numpy as np
import zmq

from Device.SunnyLinker import SunnyLinker64, RingBuffer


class zmqServer(threading.Thread):
    """Background thread exposing a command channel and an EEG data channel.

    Two ROUTER sockets are bound in the constructor: one for JSON commands
    (``cmd_port``) and one for data frames (``data_port``).  Incoming data
    frames are appended to a ring buffer and scanned for event markers.

    ZMQ sockets are not thread-safe, so all socket I/O (including the final
    close) happens on this thread once it has been started.
    """

    # Row indices extracted by the per-paradigm accessors.
    # Rows 64 and 65 are the event channel and the label-index channel.
    _MI_ROWS = [8, 15, 12, 14, 18, 23, 16, 59, 50, 58, 17, 45,
                29, 11, 10, 19, 20, 61, 51, 60, 21, 64, 65]
    _SSMVEP_ROWS = [13, 3, 2, 46, 9, 54, 47, 55, 64, 65]
    _SSVEP_ROWS = [13, 3, 2, 46, 9, 54, 47, 55, 64]
    _FRONTAL_ROWS = [0, 1]

    def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100):
        """Create the ZMQ context, bind both sockets and initialise state.

        @param host: interface address to bind to
        @param cmd_port: TCP port of the command channel
        @param data_port: TCP port of the data channel
        """
        threading.Thread.__init__(self)
        self.host = host
        self.cmd_port = cmd_port
        self.data_port = data_port
        self.running = False
        self.get_Impedance = False
        self.open_Impedance = None
        self.StartDecode = False
        self.StartTrain = False
        self.state_mode = None
        self.currentLabel = -1
        self.IsExitApp = False
        self.getReport = False
        self.daemon = True

        # Guards the one-time socket/context teardown (see _close_sockets).
        self._shutdown_lock = threading.Lock()
        self._closed = False

        # ZMQ context
        self.context = zmq.Context()

        # Command channel (default 8099) - ROUTER
        self.cmd_socket = self.context.socket(zmq.ROUTER)
        self.cmd_socket.setsockopt(zmq.RCVHWM, 1000)
        self.cmd_socket.setsockopt(zmq.SNDHWM, 1000)
        self.cmd_socket.bind(f"tcp://{self.host}:{cmd_port}")

        # Data channel (default 8100) - ROUTER
        self.data_socket = self.context.socket(zmq.ROUTER)
        self.data_socket.setsockopt(zmq.RCVHWM, 1000)
        self.data_socket.setsockopt(zmq.RCVTIMEO, 50)
        self.data_socket.bind(f"tcp://{self.host}:{data_port}")

        self.targetFreqs = []
        self.changeTarget = False
        self.sunnyLinker = SunnyLinker64(None, None, None, None, None)
        self.labels = [0x01, 0x02, 0x03]
        self.decoder_switch = False
        self.decoder_class = None
        self.cmd_clients = set()
        self.data_clients = set()
        self.send_queue = queue.Queue()

        # ========== Data buffer (RingBuffer) ==========
        # Same buffer type as SunnyLinker.
        # 66 rows = 64 EEG channels + event channel (65th) + label-index
        # channel (66th).  Sized for t_buffer seconds at 250 samples/s
        # (10 s -> 2500 points).
        self.n_chan = 66
        self.t_buffer = 10.0  # buffer length in seconds
        self.__ringBuffer = RingBuffer(self.n_chan, int(self.t_buffer * 250))

        # Event detection state
        self._event_lock = threading.Lock()
        self._epoch_finished = False
        self._event_inner_idx = -1
        self.pack_contain_event = False
        self.predict_event = 99
        self.events = [1, 2, self.predict_event]
        self.count_events = {}
        self.latency = 50
        self.train_latency = 50

        # Label index of the most recent event (read from the 66th channel)
        self.current_label_index = 0

        # Initialisation flags
        self._interval_inited = False
        self._currentLabel = -1

        # Registered clients (kept for the legacy interface)
        self.clients = set()

    # ========== Event attributes: lock-protected access ==========
    @property
    def epoch_finished(self):
        """Whether the last processed packet completed an epoch countdown."""
        with self._event_lock:
            return self._epoch_finished

    @epoch_finished.setter
    def epoch_finished(self, value):
        with self._event_lock:
            self._epoch_finished = value

    @property
    def event_inner_idx(self):
        """Sample index of the last event inside its packet, or -1."""
        with self._event_lock:
            return self._event_inner_idx

    @event_inner_idx.setter
    def event_inner_idx(self, value):
        with self._event_lock:
            self._event_inner_idx = value

    @property
    def interval_inited(self):
        """True once interval_init() has run."""
        return self._interval_inited

    @interval_inited.setter
    def interval_inited(self, value):
        self._interval_inited = value

    @property
    def currentLabel(self):
        """Label supplied with the most recent "train" command (-1 initially)."""
        return self._currentLabel

    @currentLabel.setter
    def currentLabel(self, value):
        self._currentLabel = value

    def broadcast_message(self, method, params):
        """Put message into queue to be sent to all connected clients"""
        self.send_queue.put((method, params))

    # ========== Data buffer interface ==========
    def GetDataLenCount(self):
        """Return the ring buffer's current ``nUpdate`` counter."""
        return self.__ringBuffer.nUpdate

    def getData(self, count):
        """Return a copy of the newest ``count`` samples without consuming them.

        @param count: number of samples requested; clamped to what is buffered
        @return: array of shape (n_chan, k) with k <= count.  Always a copy,
                 so later writes into the ring buffer cannot alter it.
        """
        rb = self.__ringBuffer
        with rb.RingBufferLock:
            count = min(count, rb.nUpdate, rb.n_points)
            if count <= 0:
                return np.zeros((self.n_chan, 0))

            # The newest sample sits just before the write pointer; walk
            # back `count` samples from there, wrapping around if needed.
            read_end = (rb.currentPtr - 1) % rb.n_points
            read_start = (read_end - count + 1) % rb.n_points

            if read_start <= read_end:
                return rb.buffer[:, read_start:read_end + 1].copy()
            older = rb.buffer[:, read_start:]
            newer = rb.buffer[:, :read_end + 1]
            return np.concatenate((older, newer), axis=1)

    def consumeData(self, count):
        """Discard up to ``count`` samples by advancing the read pointer."""
        rb = self.__ringBuffer
        with rb.RingBufferLock:
            count = max(0, min(count, rb.nUpdate))
            rb.readPtr = (rb.readPtr + count) % rb.n_points
            rb.nUpdate -= count

    def ResetAll(self):
        """Reset the ring buffer and all event-detection state."""
        with self.__ringBuffer.RingBufferLock:
            self.__ringBuffer.resetAllPara()
        # Taken after the buffer lock is released: detect_event() acquires
        # the two locks in the order event -> buffer, so nesting them the
        # other way round here could deadlock.
        with self._event_lock:
            self._epoch_finished = False
            self._event_inner_idx = -1
            self.pack_contain_event = False
            self.count_events.clear()
            self.current_label_index = 0

    def reset_data_buffer(self):
        """Alias of ResetAll()."""
        self.ResetAll()

    def reset_state(self):
        """Alias of ResetAll()."""
        self.ResetAll()

    def interval_init(self, decoder_class):
        """Initialise epoch intervals and countdown latencies for a paradigm.

        Epoch limits are read from the ini file in seconds and converted to
        samples at 250 samples/s.  Latencies are obtained by dividing the
        sample counts by 5.

        @param decoder_class: 'ssmvep' or 'mi'; any other value leaves the
                              current intervals/latencies untouched but
                              still resets the event state.
        """
        import ast
        from PubLibrary.InifileHelper import IniRead

        if decoder_class == 'ssmvep':
            interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
            self.interval_epoch = [int(i * 250) for i in interval_epoch]
            # The training epoch is extended by 0.1 s.
            self.train_epoch = [int(self.interval_epoch[0]),
                                int(self.interval_epoch[1] + 0.1 * 250)]
            self.latency = (self.interval_epoch[1] + 0.1 * 250) // 5
            self.train_latency = (self.train_epoch[1] + 0.1 * 250) // 5
        elif decoder_class == 'mi':
            interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
            self.interval_epoch = [int(i * 250) for i in interval_epoch]
            self.train_epoch = self.interval_epoch.copy()
            self.latency = self.interval_epoch[1] // 5
            self.train_latency = self.latency

        self.count_events = {}
        self._event_inner_idx = -1
        self._epoch_finished = False
        self.pack_contain_event = False
        self.predict_event = 99
        self.events = [1, 2, self.predict_event]
        self._interval_inited = True

    # ========== Event detection ==========
    def detect_event(self, data_matrix):
        """Scan one data packet for event markers and advance the countdowns.

        Every marker found starts a countdown of ``latency`` (prediction
        event) or ``train_latency`` (any other event) packets.  Each call
        decrements all running countdowns by one.

        @param data_matrix: shape (66, N) - N samples
                            row 64 = event channel
                            row 65 = label-index channel
        @return: True if at least one countdown expired on this packet
                 (``epoch_finished`` is set to the same value), else False.
        """
        if data_matrix.shape[1] == 0:
            return False

        self.pack_contain_event = False
        event_channel = data_matrix[64, :]  # event values
        label_channel = data_matrix[65, :]  # label index per sample
        events = event_channel.tolist()

        with self._event_lock:
            self._event_inner_idx = -1
            self.current_event_label = 0

            for idx, event in enumerate(events):
                if int(event) in self.events:
                    self._event_inner_idx = idx
                    self.current_label_index = int(label_channel[idx])
                    self.pack_contain_event = True
                    new_key = f"{event}_{time.time()}"
                    latency = self.latency if event == self.predict_event else self.train_latency
                    # +1 because the decrement below also runs on this packet.
                    self.count_events[new_key] = latency + 1

            # Decrement every running countdown; collect the expired ones.
            drop_items = []
            for key, value in self.count_events.items():
                value = value - 1
                if value <= 0:
                    drop_items.append(key)
                self.count_events[key] = value
            for key in drop_items:
                del self.count_events[key]

            if drop_items:
                self._epoch_finished = True
                # NOTE(review): block nesting here was reconstructed from a
                # whitespace-collapsed source; confirm that the buffer reset
                # belongs inside the epoch-finished branch.
                # NOTE(review): the original comment says only the current
                # packet should be kept, but _handle_data_socket() appends
                # the packet before calling this method, so the reset drops
                # it as well - confirm the intended behaviour.
                if self.pack_contain_event:
                    with self.__ringBuffer.RingBufferLock:
                        self.__ringBuffer.resetAllPara()
                return True

            self._epoch_finished = False
            return False

    def run(self):
        """Thread main loop: flush outgoing messages, then service both sockets."""
        self.running = True
        print(f"Server running - CMD: {self.cmd_port}, DATA: {self.data_port}")

        # One poller for both sockets so neither channel waits behind an
        # idle poll on the other.
        poller = zmq.Poller()
        poller.register(self.cmd_socket, zmq.POLLIN)
        poller.register(self.data_socket, zmq.POLLIN)

        try:
            while self.running:
                self._flush_send_queue()

                ready = dict(poller.poll(10))
                if self.cmd_socket in ready:
                    self._handle_cmd_socket()
                if self.data_socket in ready:
                    self._handle_data_socket()
        except Exception as e:
            print(f"Server error: {e}")
        finally:
            self.running = False
            self._close_sockets()

    def _flush_send_queue(self):
        """Send every queued broadcast to all known command clients.

        Messages queued while no command client is known are dropped.
        Sending is best-effort: failures are reported and skipped.
        """
        while True:
            try:
                method, params = self.send_queue.get_nowait()
            except queue.Empty:
                return

            if not self.cmd_clients:
                continue

            try:
                msg_bytes = json.dumps({'method': method, 'params': params}).encode('utf-8')
            except Exception as e:
                print(f"Broadcast encode error ({method}): {e}")
                continue

            for client_id in list(self.cmd_clients):
                try:
                    self.cmd_socket.send_multipart([client_id, b'', msg_bytes])
                except Exception as e:
                    print(f"Broadcast send error ({method}): {e}")

    def _handle_cmd_socket(self):
        """Receive one command message, register its sender and apply it.

        Expected frames: [identity, delimiter, json]; shorter messages are
        ignored.  The JSON object carries "method" and "params".
        """
        try:
            frames = self.cmd_socket.recv_multipart()
            if len(frames) < 3:
                return

            ident, _, message_bytes = frames[:3]
            self.cmd_clients.add(ident)
            self.clients.add(ident)

            message = json.loads(message_bytes.decode('utf-8'))
            method = message.get("method")
            params = message.get("params")
            print(f"[CMD] {method}: {params}")
            self._dispatch_command(method, params)
        except Exception as e:
            print(f"CMD socket error: {e}")

    def _dispatch_command(self, method, params):
        """Update server state flags according to one decoded command."""
        if method == "sync":
            self.state_mode = 'sync'
        elif method == "targetFreqs":
            if isinstance(params, list) and params != self.targetFreqs:
                self.targetFreqs = params
                self.changeTarget = True
        elif method == "decoderClass":
            if isinstance(params, str) and params != self.decoder_class:
                self.decoder_class = params
                self.decoder_switch = True
        elif method == "getReport":
            self.getReport = True
        elif method == "train":
            self.state_mode = 'train'
            self.StartTrain = True
            self.currentLabel = params
        elif method == "predict":
            self.state_mode = 'predict'
            if params == 1:
                self.StartDecode = True
            elif params == 2:
                # Exit request: flag the application and end the run loop.
                self.IsExitApp = True
                self.running = False
        elif method == "rest":
            self.state_mode = 'rest'
        elif method == "impedance":
            if params == 1:
                self.open_Impedance = True
                self.get_Impedance = True
            elif params == 2:
                self.open_Impedance = False
                self.get_Impedance = False

    def _handle_data_socket(self):
        """Receive one EEG data frame, buffer it and run event detection.

        Frame layout from the host PC:
            [identity, '', meta_json, data_buffer]
        ``data_buffer`` is read as float32 and reshaped to
        ``meta['shape']`` (default [5, 66], i.e. [N, 66]), then transposed
        to [66, N].
        """
        try:
            frames = self.data_socket.recv_multipart()
            if len(frames) < 4:
                return

            ident, _, message_bytes = frames[:3]
            self.data_clients.add(ident)

            meta = json.loads(message_bytes.decode('utf-8'))

            raw_data = np.frombuffer(frames[3], dtype=np.float32)
            n_samples, n_channels = meta.get('shape', [5, 66])
            if n_channels != self.n_chan:
                # detect_event() indexes rows 64/65, so reject frames that
                # do not carry the full channel set before buffering them.
                print(f"DATA socket error: expected {self.n_chan} channels, got {n_channels}")
                return
            # [N, 66] -> transpose -> [66, N]
            data_matrix = raw_data.reshape(n_samples, n_channels).T.astype(np.float32)

            with self.__ringBuffer.RingBufferLock:
                self.__ringBuffer.appendBuffer(data_matrix)

            self.detect_event(data_matrix)
        except Exception as e:
            print(f"DATA socket error: {e}")

    # ========== Per-paradigm data accessors ==========
    @staticmethod
    def _select_rows(data, rows):
        """Return the given rows of ``data``, or an empty (len(rows), 0) array."""
        if data.shape[1] > 0:
            return data[np.array(rows), :]
        return np.zeros((len(rows), 0))

    def get_MIData(self):
        """Return all buffered MI data (21 channels + event + label index)."""
        data = self.getData(self.GetDataLenCount())
        return self._select_rows(data, self._MI_ROWS)

    def get_SSMVEPData(self):
        """Return all buffered SSMVEP data (8 channels + event + label index)."""
        data = self.getData(self.GetDataLenCount())
        return self._select_rows(data, self._SSMVEP_ROWS)

    def getDataViaSSVEP(self, count):
        """Return the newest ``count`` samples of SSVEP data (8 channels + event)."""
        return self._select_rows(self.getData(count), self._SSVEP_ROWS)

    def get_concentrateData(self, count):
        """Return the newest ``count`` samples of concentration data (rows 0, 1)."""
        return self._select_rows(self.getData(count), self._FRONTAL_ROWS)

    def get_blinkData(self, count):
        """Return the newest ``count`` samples of blink data (rows 0, 1)."""
        return self._select_rows(self.getData(count), self._FRONTAL_ROWS)

    def getImpedance(self, data, decoder_class):
        """Impedance is unavailable in ZMQ mode; always returns eight zeros."""
        return np.zeros(8)

    def _close_sockets(self):
        """Close both sockets and terminate the context, at most once."""
        with self._shutdown_lock:
            if self._closed:
                return
            self._closed = True
        self.cmd_socket.close()
        self.data_socket.close()
        self.context.term()

    def stop(self):
        """Ask the server loop to end and release the ZMQ resources.

        While the thread is running, the sockets are closed by run() itself
        on its own thread; this method only signals and waits briefly.  If
        the thread was never started (or has already ended), the sockets are
        closed here.
        """
        self.running = False
        if self.is_alive() and threading.current_thread() is not self:
            self.join(timeout=1.0)
        if not self.is_alive():
            self._close_sockets()


if __name__ == '__main__':
    server = zmqServer()
    server.start()
    # The server is a daemon thread: keep the main thread alive, otherwise
    # the process would exit right after start().
    try:
        while server.is_alive():
            server.join(timeout=0.5)
    except KeyboardInterrupt:
        server.stop()