add filter process
This commit is contained in:
59
Decoder.py
59
Decoder.py
@@ -42,49 +42,33 @@ MODEL_FOLDER = "online_Models"
|
|||||||
class Decoder_main(threading.Thread):
|
class Decoder_main(threading.Thread):
|
||||||
def __init__(self, device_info=None):
|
def __init__(self, device_info=None):
|
||||||
threading.Thread.__init__(self)
|
threading.Thread.__init__(self)
|
||||||
self.device_info = {
|
self.device_info = device_info
|
||||||
'sample_rate': device_info['sample_rate'],
|
|
||||||
'frame_points': device_info['frame_points'],
|
|
||||||
'channel_nums': device_info['channel_nums'],
|
|
||||||
'channel_names': device_info['channel_names'],
|
|
||||||
'channel_index': device_info['channel_index'],
|
|
||||||
}
|
|
||||||
self.Runing=True
|
self.Runing=True
|
||||||
self.decoder = None
|
self.decoder = None
|
||||||
self.decoder_class = None #解码器类别
|
self.decoder_class = None #解码器类别
|
||||||
|
|
||||||
# 与采集设备通信的状态码,0为异常,1为正常
|
|
||||||
# self.status_code = 0
|
|
||||||
# self.device_info['sample_rate'] = 250 # 采样率
|
|
||||||
# self.energy = 0 # 电量
|
|
||||||
|
|
||||||
|
|
||||||
self.decodingSteps = 0 # 0=停止解码 1=预热 2=解码中 3=解码完成,发送解码结果
|
self.decodingSteps = 0 # 0=停止解码 1=预热 2=解码中 3=解码完成,发送解码结果
|
||||||
|
|
||||||
|
|
||||||
|
self.zmqServer = None
|
||||||
|
self.sliding_filter = None
|
||||||
|
|
||||||
|
self._init_threads()
|
||||||
|
|
||||||
|
def _init_threads(self):
|
||||||
|
"""初始化ZMQ服务和滤波线程"""
|
||||||
|
# 1. 初始化ZMQServer并启动
|
||||||
self.zmqServer = zmqServer(device_info=self.device_info)
|
self.zmqServer = zmqServer(device_info=self.device_info)
|
||||||
self.zmqServer.start()
|
self.zmqServer.start() # 启动ZMQ接收线程
|
||||||
|
|
||||||
self.filter = SlidingFilter()
|
# 2. 初始化滤波线程(关联ZMQServer的环形缓存)
|
||||||
|
self.sliding_filter = SlidingFilter(
|
||||||
|
ring_buffer=self.zmqServer.filterBuffer,
|
||||||
|
n_chan=self.zmqServer.device_info['channel_nums'],
|
||||||
|
srate=self.zmqServer.device_info['sample_rate']
|
||||||
|
)
|
||||||
|
|
||||||
# self.zmqClient = zmqClient(_upper_host, _upper_port)
|
# 注册滤波结果回调(示例:打印数据形状)
|
||||||
# self.zmqClient.set_zmq_server(self.zmqServer)
|
self.sliding_filter.set_result_callback(self.zmqServer.send_filtered_data)
|
||||||
# self.zmqClient.connect()
|
|
||||||
|
|
||||||
|
|
||||||
# def connect(self, device_type=None, device_host=None, device_port=None, upper_host=None, upper_port=None):
|
|
||||||
# self.DeviceType = device_type if device_type is not None else int(IniRead('system', 'Device_type'))
|
|
||||||
# _device_host = device_host if device_host is not None else str(IniRead('system', 'Device_Host'))
|
|
||||||
# _device_port = device_port if device_port is not None else int(IniRead('system', 'Device_Port'))
|
|
||||||
# _upper_host = upper_host if upper_host is not None else str(IniRead('system', 'Upper_Host'))
|
|
||||||
# _upper_port = upper_port if upper_port is not None else int(IniRead('system', 'Upper_Port'))
|
|
||||||
|
|
||||||
# if self.DeviceType == 1:
|
|
||||||
# self.thread_data_server = SunnyLinker64(_device_host, _device_port, self.device_info['sample_rate'], 64, method='tcp')
|
|
||||||
# self.thread_data_server.host = _device_host
|
|
||||||
# self.thread_data_server.port = _device_port
|
|
||||||
|
|
||||||
# self.thread_data_server.toUv = True
|
|
||||||
# self.thread_data_server.start()
|
|
||||||
|
|
||||||
|
|
||||||
def is_valid_signal(self, data, threshold=1e5): # 判断当前信号是否为有效信号
|
def is_valid_signal(self, data, threshold=1e5): # 判断当前信号是否为有效信号
|
||||||
@@ -210,6 +194,10 @@ class Decoder_main(threading.Thread):
|
|||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
while self.Runing:
|
while self.Runing:
|
||||||
|
# 当滤波数据大于5秒时,启动滤波线程
|
||||||
|
if self.zmqServer.filterBuffer.GetDataLenCount() > self.device_info['sample_rate'] * 5:
|
||||||
|
self.sliding_filter.start()
|
||||||
|
|
||||||
if self.zmqServer.decoder_switch or self.zmqServer.changeTarget:
|
if self.zmqServer.decoder_switch or self.zmqServer.changeTarget:
|
||||||
print(f"Decoder_class Switch Detected: {self.zmqServer.decoder_class}")
|
print(f"Decoder_class Switch Detected: {self.zmqServer.decoder_class}")
|
||||||
self.zmqServer.decoder_switch = False
|
self.zmqServer.decoder_switch = False
|
||||||
@@ -487,6 +475,7 @@ class Decoder_main(threading.Thread):
|
|||||||
@return:
|
@return:
|
||||||
'''
|
'''
|
||||||
self.zmqServer.stop()
|
self.zmqServer.stop()
|
||||||
|
self.sliding_filter.stop()
|
||||||
self.Runing=False
|
self.Runing=False
|
||||||
|
|
||||||
def reset_state(self):
|
def reset_state(self):
|
||||||
|
|||||||
@@ -13,13 +13,6 @@ Debug_64ch_Decoder_Optimize is an updated version that fixes several issues and
|
|||||||
6. decoder class切换问题
|
6. decoder class切换问题
|
||||||
7. decoder_class切换时,数据重置、各类参数重置
|
7. decoder_class切换时,数据重置、各类参数重置
|
||||||
|
|
||||||
# update
|
|
||||||
2026年6月5日13:55:34
|
|
||||||
|
|
||||||
# 遗留问题
|
|
||||||
1. 之前当处于阻抗检测状态时,Decoder在空跑。当前无法判断是否处于阻抗检测状态。
|
|
||||||
- 解决方法,保留之前发阻抗命令
|
|
||||||
|
|
||||||
|
|
||||||
# 常用命令
|
# 常用命令
|
||||||
source activate 3in1Py310
|
source activate 3in1Py310
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
数据滤波模块
|
数据滤波模块
|
||||||
"""
|
"""
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import time
|
||||||
import threading
|
import threading
|
||||||
from scipy import signal
|
from scipy import signal
|
||||||
from logs.log import algo_log
|
from logs.log import algo_log
|
||||||
@@ -10,7 +11,7 @@ from logs.log import algo_log
|
|||||||
class FilterRingBuffer:
|
class FilterRingBuffer:
|
||||||
def __init__(self, n_chan, n_points):
|
def __init__(self, n_chan, n_points):
|
||||||
"""
|
"""
|
||||||
初始化纯数据环形缓存
|
初始化纯数据环形缓存(线程安全)
|
||||||
:param n_chan: 通道数
|
:param n_chan: 通道数
|
||||||
:param n_points: 总缓存点数(与paradigmRingBuffer参数完全一致)
|
:param n_points: 总缓存点数(与paradigmRingBuffer参数完全一致)
|
||||||
"""
|
"""
|
||||||
@@ -18,11 +19,9 @@ class FilterRingBuffer:
|
|||||||
self.n_points = n_points
|
self.n_points = n_points
|
||||||
|
|
||||||
self.buffer = np.zeros((n_chan, n_points), dtype=np.float32)
|
self.buffer = np.zeros((n_chan, n_points), dtype=np.float32)
|
||||||
self.current_ptr = 0 # 写入指针
|
self.current_ptr = 0 # 写入指针:指向下一个要写入的位置
|
||||||
self.total_samples = 0 # 已写入总点数
|
self.total_samples = 0 # 已写入总点数
|
||||||
|
self.lock = threading.Lock() # 线程安全锁
|
||||||
# 线程安全锁(多线程环境必须)
|
|
||||||
self.lock = threading.Lock()
|
|
||||||
|
|
||||||
def appendBuffer(self, data):
|
def appendBuffer(self, data):
|
||||||
"""
|
"""
|
||||||
@@ -34,7 +33,7 @@ class FilterRingBuffer:
|
|||||||
if n == 0:
|
if n == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
# 环形写入逻辑
|
# 环形写入逻辑:指针到末尾则绕回
|
||||||
write_end = self.current_ptr + n
|
write_end = self.current_ptr + n
|
||||||
if write_end <= self.n_points:
|
if write_end <= self.n_points:
|
||||||
self.buffer[:, self.current_ptr:write_end] = data
|
self.buffer[:, self.current_ptr:write_end] = data
|
||||||
@@ -43,13 +42,14 @@ class FilterRingBuffer:
|
|||||||
self.buffer[:, self.current_ptr:] = data[:, :split]
|
self.buffer[:, self.current_ptr:] = data[:, :split]
|
||||||
self.buffer[:, :write_end - self.n_points] = data[:, split:]
|
self.buffer[:, :write_end - self.n_points] = data[:, split:]
|
||||||
|
|
||||||
# 更新指针和计数
|
# 更新指针(取模保证环形)和计数(不超过缓存总长度)
|
||||||
self.current_ptr = write_end % self.n_points
|
self.current_ptr = write_end % self.n_points
|
||||||
self.total_samples = min(self.total_samples + n, self.n_points)
|
self.total_samples = min(self.total_samples + n, self.n_points)
|
||||||
|
|
||||||
def getData(self, count):
|
def getData(self, count):
|
||||||
"""
|
"""
|
||||||
从读指针位置读取count个点(与paradigmRingBuffer接口一致)
|
从最新位置向前读取count个点(环形读取)
|
||||||
|
核心逻辑:current_ptr是下一个写入位置 → 最新数据在current_ptr之前
|
||||||
:param count: 读取点数
|
:param count: 读取点数
|
||||||
:return: np.ndarray, shape=(n_chan, count)
|
:return: np.ndarray, shape=(n_chan, count)
|
||||||
"""
|
"""
|
||||||
@@ -58,13 +58,14 @@ class FilterRingBuffer:
|
|||||||
if count == 0:
|
if count == 0:
|
||||||
return np.zeros((self.n_chan, 0))
|
return np.zeros((self.n_chan, 0))
|
||||||
|
|
||||||
# 环形读取逻辑(与paradigmRingBuffer完全相同)
|
# 环形读取:end是当前写入指针(最新数据的下一位),start是end - count
|
||||||
end = self.current_ptr
|
end = self.current_ptr
|
||||||
start = end - count
|
start = end - count
|
||||||
if start >= 0:
|
if start >= 0:
|
||||||
return self.buffer[:, start:end].copy()
|
return self.buffer[:, start:end].copy()
|
||||||
else:
|
else:
|
||||||
part1 = self.buffer[:, start:]
|
# 跨环形边界:前半部分从缓存末尾取,后半部分从开头取
|
||||||
|
part1 = self.buffer[:, start:] # start为负,等价于n_points + start
|
||||||
part2 = self.buffer[:, :end]
|
part2 = self.buffer[:, :end]
|
||||||
return np.concatenate((part1, part2), axis=1)
|
return np.concatenate((part1, part2), axis=1)
|
||||||
|
|
||||||
@@ -72,7 +73,7 @@ class FilterRingBuffer:
|
|||||||
"""
|
"""
|
||||||
扩展方法:获取最新的n个点(不移动读指针,用于滑动窗口)
|
扩展方法:获取最新的n个点(不移动读指针,用于滑动窗口)
|
||||||
:param n: 点数
|
:param n: 点数
|
||||||
:return: np.ndarray, shape=(n_chan, n)
|
:return: np.ndarray, shape=(n_chan, n) | None(数据不足时)
|
||||||
"""
|
"""
|
||||||
with self.lock:
|
with self.lock:
|
||||||
if self.total_samples < n:
|
if self.total_samples < n:
|
||||||
@@ -93,43 +94,37 @@ class FilterRingBuffer:
|
|||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# 2. 独立滑动滤波类(仅负责滤波业务逻辑,不关心缓存实现)
|
# 2. 独立滑动滤波类(仅负责滤波业务逻辑,不关心缓存实现)
|
||||||
# 可替换任意缓存实现,只要实现appendBuffer、get_latest_n_points接口
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
class SlidingFilter:
|
class SlidingFilter(threading.Thread):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
ring_buffer: FilterRingBuffer,
|
||||||
n_chan=66,
|
n_chan=66,
|
||||||
srate=250,
|
srate=250,
|
||||||
buffer_sec=5,
|
|
||||||
window_sec=3,
|
window_sec=3,
|
||||||
step_sec=0.2,
|
step_sec=0.2, # 200ms滑动步长
|
||||||
packet_size=5
|
packet_size=5
|
||||||
):
|
):
|
||||||
"""
|
super().__init__(daemon=True)
|
||||||
初始化滑动滤波器
|
|
||||||
:param n_chan: 通道数
|
|
||||||
:param srate: 采样率
|
|
||||||
:param buffer_sec: 总缓存时长(秒)
|
|
||||||
:param window_sec: 滤波窗口时长(秒)
|
|
||||||
:param step_sec: 滑动步长/输出时长(秒)
|
|
||||||
:param packet_size: 每包数据点数(20ms一包=5点)
|
|
||||||
"""
|
|
||||||
# 核心参数
|
# 核心参数
|
||||||
self.n_chan = n_chan
|
self.n_chan = n_chan
|
||||||
self.srate = srate
|
self.srate = srate
|
||||||
self.buffer_size = int(srate * buffer_sec)
|
self.step_sec = step_sec # 200ms滑动步长
|
||||||
self.window_size = int(srate * window_sec)
|
self.window_sec = window_sec # 3秒窗口
|
||||||
self.step_size = int(srate * step_sec)
|
self.step_sec = step_sec # 200ms滑动步长
|
||||||
|
self.window_size = int(srate * window_sec) # 3秒点数:250*3=750
|
||||||
|
self.step_size = int(srate * step_sec) # 200ms点数:250*0.2=50
|
||||||
self.packet_size = packet_size
|
self.packet_size = packet_size
|
||||||
|
|
||||||
# 初始化纯数据缓存(解耦核心)
|
# 关联ZMQServer的环形缓存(解耦:仅依赖接口)
|
||||||
self.buffer = FilterRingBuffer(n_chan, self.buffer_size)
|
self.ring_buffer = ring_buffer
|
||||||
|
# 线程控制
|
||||||
|
self.running = threading.Event()
|
||||||
|
self.running.set()
|
||||||
|
# 滤波结果回调(外部可注册,获取滤波后的数据)
|
||||||
|
self.filter_result_callback = None
|
||||||
|
|
||||||
# 滤波触发计数器
|
# 预计算滤波器系数(仅执行一次)
|
||||||
self.packet_count = 0
|
|
||||||
self.ready_to_filter = False
|
|
||||||
|
|
||||||
# 预计算滤波器系数
|
|
||||||
self._init_filters()
|
self._init_filters()
|
||||||
|
|
||||||
def _init_filters(self):
|
def _init_filters(self):
|
||||||
@@ -145,65 +140,60 @@ class SlidingFilter:
|
|||||||
)
|
)
|
||||||
self.a_bp = np.array([1.0])
|
self.a_bp = np.array([1.0])
|
||||||
|
|
||||||
def append_and_check_trigger(self, raw_data):
|
def _filter_window_data(self, window_data):
|
||||||
"""
|
"""对3秒窗口数据执行滤波,返回无边界效应的200ms数据"""
|
||||||
追加单包原始数据并检查是否触发滤波
|
|
||||||
:param raw_data: 上位机原始数据,shape=(packet_size, n_chan)
|
|
||||||
:return: bool: 是否触发本次滤波
|
|
||||||
"""
|
|
||||||
# 转置为标准格式:(通道数, 点数)
|
|
||||||
data = raw_data.T.astype(np.float64)
|
|
||||||
|
|
||||||
# 写入缓存(纯缓存操作)
|
|
||||||
self.buffer.appendBuffer(data)
|
|
||||||
|
|
||||||
# 更新包计数器
|
|
||||||
self.packet_count += 1
|
|
||||||
|
|
||||||
# 检查滤波触发条件:数据≥窗口长度 且 累计满一个步长的包数
|
|
||||||
packets_per_step = int(self.step_size / self.packet_size) # 10包=200ms
|
|
||||||
if (self.buffer.GetDataLenCount() >= self.window_size
|
|
||||||
and self.packet_count >= packets_per_step):
|
|
||||||
self.packet_count = 0
|
|
||||||
self.ready_to_filter = True
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def filter_and_get_output(self):
|
|
||||||
"""
|
|
||||||
执行滤波并返回无边界效应的输出数据
|
|
||||||
:return: np.ndarray: 滤波后数据,shape=(n_chan, step_size)
|
|
||||||
"""
|
|
||||||
if not self.ready_to_filter:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 获取最新的完整滤波窗口数据
|
|
||||||
window_data = self.buffer.get_latest_n_points(self.window_size)
|
|
||||||
if window_data is None:
|
|
||||||
self.ready_to_filter = False
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 零相位滤波(无延迟,无边界效应)
|
# 零相位滤波(无延迟,无边界效应)
|
||||||
filtered = window_data - np.mean(window_data, axis=-1, keepdims=True)
|
filtered = window_data - np.mean(window_data, axis=-1, keepdims=True)
|
||||||
filtered = signal.filtfilt(self.b_notch, self.a_notch, filtered, axis=-1)
|
filtered = signal.filtfilt(self.b_notch, self.a_notch, filtered, axis=-1)
|
||||||
filtered = signal.filtfilt(self.b_bp, self.a_bp, filtered, axis=-1)
|
filtered = signal.filtfilt(self.b_bp, self.a_bp, filtered, axis=-1)
|
||||||
|
|
||||||
# 提取倒数第二个步长的数据(完全避开两端边界效应)
|
# 提取倒数第二个200ms的数据(完全避开两端边界效应)
|
||||||
|
# 窗口长度750,步长50 → start=750-100=650,end=750-50=700
|
||||||
start_idx = self.window_size - 2 * self.step_size
|
start_idx = self.window_size - 2 * self.step_size
|
||||||
end_idx = self.window_size - self.step_size
|
end_idx = self.window_size - self.step_size
|
||||||
output_data = filtered[:, start_idx:end_idx].copy()
|
output_data = filtered[:, start_idx:end_idx].copy()
|
||||||
|
|
||||||
# 重置触发标志
|
|
||||||
self.ready_to_filter = False
|
|
||||||
|
|
||||||
return output_data
|
return output_data
|
||||||
|
|
||||||
def reset(self):
|
def run(self):
|
||||||
"""重置滤波器和缓存"""
|
"""线程主逻辑:精确200ms触发一次滤波"""
|
||||||
self.buffer.resetAllPara()
|
# 精确定时核心:基于perf_counter计算下一次执行时间,补偿sleep误差
|
||||||
self.packet_count = 0
|
interval = self.step_sec # 200ms = 0.2秒
|
||||||
self.ready_to_filter = False
|
next_run_time = time.perf_counter()
|
||||||
|
|
||||||
def get_buffer_length(self):
|
while self.running.is_set():
|
||||||
"""获取当前缓存数据长度"""
|
# 1. 等待到下一次执行时间(精确定时)
|
||||||
return self.buffer.GetDataLenCount()
|
current_time = time.perf_counter()
|
||||||
|
if current_time < next_run_time:
|
||||||
|
time.sleep(next_run_time - current_time)
|
||||||
|
next_run_time += interval # 补偿:下次执行时间基于上一次目标时间
|
||||||
|
else:
|
||||||
|
# 若超时(如滤波耗时超过200ms),重置下一次时间(避免累积误差)
|
||||||
|
algo_log("滤波耗时超过200ms,定时偏移", level='debug')
|
||||||
|
next_run_time = time.perf_counter() + interval
|
||||||
|
|
||||||
|
# 2. 执行滤波逻辑
|
||||||
|
try:
|
||||||
|
# 获取最新的3秒窗口数据
|
||||||
|
window_data = self.ring_buffer.get_latest_n_points(self.window_size)
|
||||||
|
if window_data is None:
|
||||||
|
algo_log(f"缓存数据不足,当前缓存{self.ring_buffer.GetDataLenCount()}点,需{self.window_size}点", level='debug')
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 滤波并提取无边界效应的200ms数据
|
||||||
|
filtered_data = self._filter_window_data(window_data)
|
||||||
|
|
||||||
|
# 回调返回结果(外部可处理)
|
||||||
|
if self.filter_result_callback is not None:
|
||||||
|
self.filter_result_callback(filtered_data[:64, :]) # 只发送前64通道数据
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
algo_log(f"滤波执行异常: {e}", level='error')
|
||||||
|
|
||||||
|
def set_result_callback(self, callback):
|
||||||
|
"""注册滤波结果回调函数"""
|
||||||
|
self.filter_result_callback = callback
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
"""停止滤波线程"""
|
||||||
|
self.running.clear()
|
||||||
|
self.join(timeout=1)
|
||||||
|
|||||||
341
Zmq/zmqServer.py
341
Zmq/zmqServer.py
@@ -1,3 +1,4 @@
|
|||||||
|
# -*-coding:utf-8 -*-
|
||||||
import ast
|
import ast
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import threading
|
import threading
|
||||||
@@ -7,7 +8,6 @@ from typing import Dict
|
|||||||
import datetime
|
import datetime
|
||||||
import time
|
import time
|
||||||
|
|
||||||
# from Device.SunnyLinker import SunnyLinker64
|
|
||||||
from Zmq.dataBuffer import ParadigmRingBuffer
|
from Zmq.dataBuffer import ParadigmRingBuffer
|
||||||
from Zmq.filterProcess import FilterRingBuffer
|
from Zmq.filterProcess import FilterRingBuffer
|
||||||
from PubLibrary.InifileHelper import IniRead
|
from PubLibrary.InifileHelper import IniRead
|
||||||
@@ -21,63 +21,68 @@ class zmqServer(threading.Thread):
|
|||||||
self.device_info = device_info
|
self.device_info = device_info
|
||||||
|
|
||||||
self.host = host
|
self.host = host
|
||||||
self.cmd_port = cmd_port # 命令交互端口
|
self.cmd_port = cmd_port # 命令交互端口:收JSON命令 + 返JSON结果
|
||||||
self.data_port = data_port # 数据接收端口
|
self.data_port = data_port # 数据交互端口:收二进制原始脑电 + 返二进制滤波结果
|
||||||
self.running = False
|
self.running = False
|
||||||
|
|
||||||
# 原有业务状态变量
|
# 原有业务状态变量
|
||||||
# self.get_Impedance = False # 是否返回阻抗值
|
self.open_Impedance = False #当前系统处于阻抗检测状态
|
||||||
self.open_Impedance = False # 是否开启阻抗检测功能
|
self.StartDecode = False
|
||||||
self.StartDecode = False # false 停止解码,true=开始解码
|
self.StartTrain = False
|
||||||
self.StartTrain = False # False未进入训练状态,True处于训练状态
|
self.state_mode = None
|
||||||
self.state_mode = None # 'train'为训练状态,’rest'为休息状态,'test'为测试状态
|
self.currentLabel = -1
|
||||||
self.currentLabel = -1 # 接收刺激端消息,了解刺激端当前的训练标签
|
self.IsExitApp = False
|
||||||
self.IsExitApp = False # 当socket收到2的时候,就置为True,代表要退出系统了。
|
|
||||||
# self.getReport = False # 获取训练报告内容
|
|
||||||
self.daemon = True
|
self.daemon = True
|
||||||
|
|
||||||
# 范式数据缓存
|
# 双环形缓冲区
|
||||||
self.paradigmBuffer = ParadigmRingBuffer(self.device_info['channel_nums'], self.device_info['sample_rate'] * 10)
|
self.paradigmBuffer = ParadigmRingBuffer(
|
||||||
self.filterBuffer = FilterRingBuffer(self.device_info['channel_nums'], self.device_info['sample_rate'] * 10)
|
self.device_info['channel_nums'],
|
||||||
self.paradigmBufferLock= threading.Lock()
|
self.device_info['sample_rate'] * 10
|
||||||
|
)
|
||||||
|
self.filterBuffer = FilterRingBuffer(
|
||||||
|
self.device_info['channel_nums'],
|
||||||
|
self.device_info['sample_rate'] * 10
|
||||||
|
)
|
||||||
|
self.paradigmBufferLock = threading.Lock()
|
||||||
|
self.filterBufferLock = threading.Lock()
|
||||||
|
|
||||||
|
# ZMQ上下文与套接字
|
||||||
# 命令与数据通信
|
|
||||||
self.context = zmq.Context()
|
self.context = zmq.Context()
|
||||||
# 指令通道 (8099) - ROUTER:短JSON命令,低频率
|
|
||||||
|
# 8099命令端口:ROUTER
|
||||||
self.cmd_socket = self.context.socket(zmq.ROUTER)
|
self.cmd_socket = self.context.socket(zmq.ROUTER)
|
||||||
# 通用套接字选项:仍在 SocketOption 中
|
|
||||||
self.cmd_socket.setsockopt(zmq.SocketOption.RCVHWM, 100)
|
self.cmd_socket.setsockopt(zmq.SocketOption.RCVHWM, 100)
|
||||||
self.cmd_socket.setsockopt(zmq.SocketOption.SNDHWM, 100)
|
self.cmd_socket.setsockopt(zmq.SocketOption.SNDHWM, 100)
|
||||||
self.cmd_socket.bind(f"tcp://{self.host}:{cmd_port}")
|
self.cmd_socket.bind(f"tcp://{self.host}:{cmd_port}")
|
||||||
|
|
||||||
# 数据通道 (8100) - ROUTER:高频脑电二进制流
|
# 8100数据端口:ROUTER
|
||||||
self.data_socket = self.context.socket(zmq.ROUTER)
|
self.data_socket = self.context.socket(zmq.ROUTER)
|
||||||
self.data_socket.setsockopt(zmq.SocketOption.RCVHWM, 500)
|
self.data_socket.setsockopt(zmq.SocketOption.RCVHWM, 500)
|
||||||
|
self.data_socket.setsockopt(zmq.SocketOption.SNDHWM, 100) # 添加发送高水位线
|
||||||
self.data_socket.bind(f"tcp://{self.host}:{data_port}")
|
self.data_socket.bind(f"tcp://{self.host}:{data_port}")
|
||||||
|
|
||||||
# Poller 轮训器(保持不变)
|
# Poller轮询器
|
||||||
self.poller = zmq.Poller()
|
self.poller = zmq.Poller()
|
||||||
self.poller.register(self.cmd_socket, zmq.POLLIN)
|
self.poller.register(self.cmd_socket, zmq.POLLIN)
|
||||||
self.poller.register(self.data_socket, zmq.POLLIN)
|
self.poller.register(self.data_socket, zmq.POLLIN)
|
||||||
|
|
||||||
# 业务变量
|
# 业务变量
|
||||||
self.targetFreqs = []
|
self.targetFreqs = []
|
||||||
self.changeTarget = False # 更换目标频率
|
self.changeTarget = False
|
||||||
# self.sunnyLinker = SunnyLinker64(None, None, None, None,None) #单例模式类,已在Decoder实例化
|
self.labels = [0x01, 0x02, 0x03]
|
||||||
self.labels = [0x01, 0x02,0x03]
|
self.decoder_switch = False
|
||||||
self.decoder_switch = False #更换解码器
|
self.decoder_class = None
|
||||||
self.decoder_class = None #解码器类别 'ssvep','ssmvep','mi'
|
|
||||||
|
|
||||||
# 客户端管理 - 区分命令/数据客户端
|
# 客户端管理(单客户端场景)
|
||||||
self.cmd_clients = set() # 命令端口客户端ID
|
self.cmd_clients = set()
|
||||||
self.data_clients = set() # 数据端口客户端ID
|
self.data_clients = set()
|
||||||
self.send_queue = queue.Queue() # 发送队列(仅用于命令端口广播)
|
self.current_data_client = None # 唯一数据客户端身份,用于发送滤波结果
|
||||||
|
|
||||||
|
# 发送队列(双端口分离)
|
||||||
|
self.cmd_send_queue = queue.Queue() # 8099端口命令结果队列
|
||||||
|
self.data_send_queue = queue.Queue() # 8100端口滤波数据队列
|
||||||
|
|
||||||
# 范式buffer参数, 事件检测相关
|
# 范式buffer与事件检测参数
|
||||||
self._event_lock = threading.Lock()
|
|
||||||
|
|
||||||
self.predict_event = 99
|
self.predict_event = 99
|
||||||
self.events = [1, 2, self.predict_event]
|
self.events = [1, 2, self.predict_event]
|
||||||
self.latency = 50
|
self.latency = 50
|
||||||
@@ -98,60 +103,131 @@ class zmqServer(threading.Thread):
|
|||||||
self.event_inner_idx = -1
|
self.event_inner_idx = -1
|
||||||
self.interval_inited = False
|
self.interval_inited = False
|
||||||
|
|
||||||
|
|
||||||
def interval_init(self, decoder_class):
|
def interval_init(self, decoder_class):
|
||||||
if decoder_class == 'ssmvep':
|
if decoder_class == 'ssmvep':
|
||||||
interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
|
interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
|
||||||
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] # epoch截取信息
|
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch]
|
||||||
self.train_epoch = [int(self.interval_epoch[0]),
|
self.train_epoch = [
|
||||||
int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])] # 训练样本epoch
|
int(self.interval_epoch[0]),
|
||||||
self.latency = (self.interval_epoch[
|
int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])
|
||||||
1] + 0.1 * self.device_info['sample_rate']) // 5 # 提取epoch的延迟标记,5代表每次解包得到的5位采样点;0.1表示比实际需要的长度多取0.1,会被截掉
|
]
|
||||||
|
self.latency = (self.interval_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5
|
||||||
self.train_latency = (self.train_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5
|
self.train_latency = (self.train_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5
|
||||||
|
|
||||||
elif decoder_class == 'mi':
|
elif decoder_class == 'mi':
|
||||||
interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
|
interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
|
||||||
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] # epoch截取信息
|
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch]
|
||||||
self.train_epoch = self.interval_epoch.copy()
|
self.train_epoch = self.interval_epoch.copy()
|
||||||
self.latency = (self.interval_epoch[1]) // 5 # 提取epoch的延迟标记,5代表每次解包得到的5位采样点;
|
self.latency = self.interval_epoch[1] // 5
|
||||||
self.train_latency = self.latency
|
self.train_latency = self.latency
|
||||||
|
|
||||||
print('时间窗:', (interval_epoch))
|
algo_log(f"时间窗初始化完成: {interval_epoch}", level="INFO")
|
||||||
self.count_events: Dict[str, int] = {} # 表示包延迟的计数信息
|
self.count_events: Dict[str, int] = {}
|
||||||
self.event_inner_idx = -1 # event在5位数据包内部的idx
|
self.event_inner_idx = -1
|
||||||
self.epoch_finished = False # 接收epoch是否完整
|
self.epoch_finished = False
|
||||||
self.pack_contain_event = False # 当前包是否含有event
|
self.pack_contain_event = False
|
||||||
self.predict_event = 99
|
self.predict_event = 99
|
||||||
self.events = [1, 2, self.predict_event]
|
self.events = [1, 2, self.predict_event]
|
||||||
self.interval_inited = True
|
self.interval_inited = True
|
||||||
|
|
||||||
|
# -------------------------- 8099端口:命令结果广播 --------------------------
|
||||||
def broadcast_message(self, method, params):
|
def broadcast_message(self, method, params):
|
||||||
"""Put message into queue to be sent to all command clients"""
|
"""
|
||||||
self.send_queue.put((method, params))
|
向所有8099端口客户端广播JSON格式的命令结果
|
||||||
|
用于:解码结果、训练状态、错误提示、进度通知等
|
||||||
|
"""
|
||||||
|
self.cmd_send_queue.put((method, params))
|
||||||
|
|
||||||
def _handle_cmd_message(self, frames):
|
def _process_cmd_send_queue(self):
|
||||||
"""处理命令端口消息(原有命令交互逻辑)"""
|
"""处理8099端口发送队列,在主线程执行(保证ZMQ线程安全)"""
|
||||||
if len(frames) < 3:
|
while not self.cmd_send_queue.empty():
|
||||||
|
method, params = self.cmd_send_queue.get()
|
||||||
|
if not self.cmd_clients:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
msg = {'method': method, 'params': params}
|
||||||
|
msg_bytes = json.dumps(msg).encode('utf-8')
|
||||||
|
|
||||||
|
algo_log(f"发送命令结果: {msg}", level="DEBUG")
|
||||||
|
|
||||||
|
# 广播到所有命令客户端
|
||||||
|
for client_id in list(self.cmd_clients):
|
||||||
|
try:
|
||||||
|
self.cmd_socket.send_multipart([client_id, b"", msg_bytes])
|
||||||
|
except Exception as e:
|
||||||
|
algo_log(f"向命令客户端{client_id}发送失败: {e}", level="ERROR")
|
||||||
|
self.cmd_clients.discard(client_id)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
algo_log(f"命令结果打包失败: {e}", level="ERROR")
|
||||||
|
|
||||||
|
# -------------------------- 8100端口:滤波结果发送 --------------------------
|
||||||
|
def send_filtered_data(self, filtered_data):
|
||||||
|
"""
|
||||||
|
向8100端口客户端发送二进制格式的滤波结果
|
||||||
|
用于:上位机实时绘图的脑电波形数据
|
||||||
|
:param filtered_data: 滤波后数据,shape=(通道数, 50),float64格式
|
||||||
|
"""
|
||||||
|
if self.current_data_client is None:
|
||||||
|
algo_log("数据客户端未连接,跳过滤波数据发送", level="WARNING")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# 转置为上位机需要的[50, 通道数]格式
|
||||||
|
filtered_data = filtered_data.T.astype(np.float32)
|
||||||
|
send_buf = filtered_data.tobytes()
|
||||||
|
algo_log(f"发送滤波数据,长度: {len(send_buf)}字节, filtered_data.shape: {filtered_data.shape}", level="DEBUG")
|
||||||
|
self.data_send_queue.put(send_buf)
|
||||||
|
|
||||||
|
def _process_data_send_queue(self):
|
||||||
|
"""处理8100端口发送队列,在主线程执行(保证ZMQ线程安全)"""
|
||||||
|
while not self.data_send_queue.empty():
|
||||||
|
send_buf = self.data_send_queue.get()
|
||||||
|
if self.current_data_client is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 标准ROUTER发送格式:[客户端ID, 空分隔帧, 数据帧]
|
||||||
|
self.data_socket.send_multipart([
|
||||||
|
self.current_data_client,
|
||||||
|
b"",
|
||||||
|
send_buf
|
||||||
|
])
|
||||||
|
algo_log(f"发送滤波数据成功,长度: {len(send_buf)}字节", level="DEBUG")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
algo_log(f"发送滤波数据失败: {e}", level="ERROR")
|
||||||
|
# 客户端断开,重置身份
|
||||||
|
self.current_data_client = None
|
||||||
|
self.data_clients.clear()
|
||||||
|
|
||||||
|
# -------------------------- 命令端口消息处理 --------------------------
|
||||||
|
def _handle_cmd_message(self, frames):
|
||||||
|
"""处理8099端口JSON命令消息"""
|
||||||
|
if len(frames) < 3:
|
||||||
|
algo_log(f"无效命令帧:长度不足3帧,实际{len(frames)}", level="ERROR")
|
||||||
|
return
|
||||||
|
|
||||||
ident, _, message_bytes = frames[:3]
|
ident, _, message_bytes = frames[:3]
|
||||||
|
|
||||||
# 注册新的命令客户端
|
# 注册新的命令客户端
|
||||||
if ident not in self.cmd_clients:
|
if ident not in self.cmd_clients:
|
||||||
self.cmd_clients.add(ident)
|
self.cmd_clients.add(ident)
|
||||||
algo_log(f"New CMD Client Connected: {ident} (port: {self.cmd_port})")
|
algo_log(f"新命令客户端连接成功: {ident}", level="INFO")
|
||||||
|
|
||||||
# 解析消息
|
# 解析JSON命令
|
||||||
try:
|
try:
|
||||||
message = json.loads(message_bytes.decode('utf-8'))
|
message = json.loads(message_bytes.decode('utf-8'))
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
algo_log(f"Invalid JSON from CMD client {ident}")
|
algo_log(f"无效JSON命令: {message_bytes.hex()}", level="ERROR")
|
||||||
|
self.broadcast_message("error", {"code": 400, "message": "无效JSON格式"})
|
||||||
return
|
return
|
||||||
algo_log(f"Received CMD request: {message}")
|
|
||||||
|
|
||||||
|
algo_log(f"收到命令: {message}", level="INFO")
|
||||||
method = message.get("method")
|
method = message.get("method")
|
||||||
params = message.get("params")
|
params = message.get("params")
|
||||||
|
|
||||||
# 原有命令处理逻辑
|
# 命令处理逻辑
|
||||||
if method == "sync":
|
if method == "sync":
|
||||||
self.state_mode = 'sync'
|
self.state_mode = 'sync'
|
||||||
elif method == "targetFreqs":
|
elif method == "targetFreqs":
|
||||||
@@ -163,108 +239,89 @@ class zmqServer(threading.Thread):
|
|||||||
self.changeTarget = True
|
self.changeTarget = True
|
||||||
elif method == "decoderClass":
|
elif method == "decoderClass":
|
||||||
if not isinstance(params, str):
|
if not isinstance(params, str):
|
||||||
algo_log(f"decoderClass must be a str")
|
algo_log(f"decoderClass必须是字符串")
|
||||||
return
|
return
|
||||||
if params != self.decoder_class:
|
if params != self.decoder_class:
|
||||||
self.decoder_class = params
|
self.decoder_class = params
|
||||||
self.decoder_switch = True
|
self.decoder_switch = True
|
||||||
elif method == "train":#训练状态
|
elif method == "train":
|
||||||
self.state_mode = 'train'
|
self.state_mode = 'train'
|
||||||
self.StartTrain = True
|
self.StartTrain = True
|
||||||
self.currentLabel = params # 当前刺激端的训练标签
|
self.currentLabel = params
|
||||||
# self.sunnyLinker.push_trigger(self.labels[self.currentLabel])
|
elif method == "predict":
|
||||||
elif method == "predict":#预测状态
|
|
||||||
self.state_mode = 'predict'
|
self.state_mode = 'predict'
|
||||||
if params == 1: #开始解码
|
if params == 1: #开始解码
|
||||||
self.StartDecode = True
|
self.StartDecode = True
|
||||||
# self.sunnyLinker.push_trigger(0x63)
|
|
||||||
elif params == 2: #停止解码
|
elif params == 2: #停止解码
|
||||||
self.IsExitApp = True
|
self.IsExitApp = True
|
||||||
self.running = False
|
self.running = False
|
||||||
elif method == "rest": #休息状态
|
elif method == "rest":
|
||||||
self.state_mode = 'rest'
|
self.state_mode = 'rest'
|
||||||
elif method == "impedance":
|
elif method == "impedance":
|
||||||
if params == 1:
|
if params == 1:
|
||||||
self.open_Impedance = True # 开启阻抗
|
self.open_Impedance = True
|
||||||
# self.get_Impedance = True # 返回阻抗
|
|
||||||
elif params == 2:
|
elif params == 2:
|
||||||
self.open_Impedance = False # 关闭阻抗
|
self.open_Impedance = False
|
||||||
else:
|
else:
|
||||||
algo_log(f"未知命令:{method}", level="WARNING")
|
self.broadcast_message("error", {"code": 404, "message": f"未知命令: {method}"})
|
||||||
|
|
||||||
# elif method == "getReport":
|
|
||||||
# self.getReport = True
|
|
||||||
|
|
||||||
# elif params == 2:
|
|
||||||
# self.open_Impedance = False # 关闭阻抗
|
|
||||||
# self.get_Impedance = False # 停止返回阻抗
|
|
||||||
|
|
||||||
|
# -------------------------- 数据端口消息处理 --------------------------
|
||||||
def _handle_data_message(self, frames):
|
def _handle_data_message(self, frames):
|
||||||
"""
|
"""处理8100端口二进制脑电数据消息"""
|
||||||
处理8100端口原始脑电二进制数据
|
algo_log(f"收到数据帧,总帧数:{len(frames)}", level="DEBUG", record_once=True)
|
||||||
固定格式:上位机发送 (5,66) float32 二维数组字节流(已转换为微伏物理量)→ 转置为 (66,5) 写入双缓冲区
|
# 然后再进行解析
|
||||||
"""
|
if len(frames) == 4:
|
||||||
# 1. 校验ZMQ消息帧完整性(ROUTER接收DEALER消息的帧格式:[客户端ID, 发送方ID, 空帧, 数据帧])
|
# 你的上位机格式
|
||||||
if len(frames) < 4: # 至少需要4帧
|
ident, sender_ident, empty_sep, data_bytes = frames[:4]
|
||||||
algo_log(f"Invalid data frame: 帧数量不足,期望≥4,实际{len(frames)}", level="ERROR")
|
elif len(frames) == 3:
|
||||||
|
# 标准格式
|
||||||
|
ident, empty_sep, data_bytes = frames[:3]
|
||||||
|
else:
|
||||||
return
|
return
|
||||||
|
# 注册新的数据客户端(单客户端场景,自动覆盖旧身份)
|
||||||
# 2. 正确解析帧(适配DEALER→ROUTER的帧格式)
|
if ident not in self.data_clients:
|
||||||
client_ident, sender_ident, empty_sep, data_bytes = frames[:4]
|
self.data_clients.clear() # 单客户端,只保留最新连接
|
||||||
if empty_sep != b'': # 校验空分隔帧
|
self.data_clients.add(ident)
|
||||||
algo_log(f"Invalid frame separator: 期望空字节,实际{empty_sep}", level="ERROR")
|
self.current_data_client = ident
|
||||||
return
|
algo_log(f"新数据客户端连接成功: {ident}", level="INFO")
|
||||||
|
|
||||||
# 3. 客户端管理(单客户端场景,自动更新最新身份)
|
|
||||||
if client_ident not in self.data_clients:
|
|
||||||
self.data_clients.add(client_ident)
|
|
||||||
self.current_data_client = client_ident # 保存唯一客户端身份,用于后续回复滤波结果
|
|
||||||
print(f"[INFO] 新数据客户端连接成功:{client_ident}")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 4. 精确长度校验(核心:固定(5,66) float32 = 5*66*4=1320字节)
|
# 精确长度校验
|
||||||
EXPECTED_BYTES = self.device_info['frame_points'] * self.device_info['channel_nums'] * 4 # 每个float32占4字节
|
EXPECTED_BYTES = self.device_info['frame_points'] * self.device_info['channel_nums'] * 4
|
||||||
if len(data_bytes) != EXPECTED_BYTES:
|
if len(data_bytes) != EXPECTED_BYTES:
|
||||||
algo_log(f"[ERROR] 数据长度错误:期望{EXPECTED_BYTES}字节,实际{len(data_bytes)}字节", level="ERROR")
|
algo_log(f"数据长度错误:期望{EXPECTED_BYTES}字节,实际{len(data_bytes)}字节", level="ERROR")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 5. 零拷贝二进制解析 + 维度转换
|
# 零拷贝解析 + 维度转换
|
||||||
|
|
||||||
data_np = np.frombuffer(data_bytes, dtype=np.float32)
|
data_np = np.frombuffer(data_bytes, dtype=np.float32)
|
||||||
data_np = data_np.reshape(self.device_info['frame_points'], self.device_info['channel_nums'])
|
data_np = data_np.reshape(self.device_info['frame_points'], self.device_info['channel_nums'])
|
||||||
data_np = data_np.T.astype(np.float64)
|
data_np = data_np.T.astype(np.float64)
|
||||||
|
|
||||||
# 6. 写入滤波缓冲区
|
# 写入滤波缓冲区
|
||||||
|
with self.filterBufferLock:
|
||||||
self.filterBuffer.appendBuffer(data_np)
|
self.filterBuffer.appendBuffer(data_np)
|
||||||
|
|
||||||
# 7. 写入范式缓冲区
|
# 写入范式缓冲区
|
||||||
try:
|
|
||||||
with self.paradigmBufferLock:
|
with self.paradigmBufferLock:
|
||||||
if self.interval_inited:
|
if self.interval_inited:
|
||||||
self.epoch_finished = self.detect_event(data_np)
|
self.epoch_finished = self.detect_event(data_np)
|
||||||
if self.pack_contain_event:
|
if self.pack_contain_event:
|
||||||
self.paradigmBuffer.resetAllPara() # 检测到当前pack含有event,清除ringbuffer中之前的数据
|
self.paradigmBuffer.resetAllPara()
|
||||||
self.paradigmBuffer.appendBuffer(data_np)
|
self.paradigmBuffer.appendBuffer(data_np)
|
||||||
if self.epoch_finished:
|
if self.epoch_finished:
|
||||||
time.sleep(0.005)
|
algo_log('Epoch采集完成: ' + datetime.datetime.now().strftime('%H:%M:%S.%f')[:-3], level="DEBUG")
|
||||||
algo_log('epoch_finished: ' + datetime.datetime.now().strftime('%H:%M:%S.%f')[:-3], level="DEBUG")
|
|
||||||
else:
|
else:
|
||||||
self.paradigmBuffer.appendBuffer(data_np)
|
self.paradigmBuffer.appendBuffer(data_np)
|
||||||
except Exception as e:
|
|
||||||
print("锁:写入异常",e)
|
|
||||||
self.paradigmBuffer.appendBuffer(data_np)
|
|
||||||
|
|
||||||
# algo_log(f"数据写入成功:shape={data_np.shape}, 范围=[{data_np.min():.2f}, {data_np.max():.2f}] μV", level="DEBUG")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
algo_log(f"数据处理失败:{str(e)}", level="ERROR")
|
algo_log(f"数据处理失败: {str(e)}", level="ERROR")
|
||||||
if IniRead('system', 'algo_log_level', 'INFO') == 'DEBUG':
|
if IniRead('system', 'algo_log_level', 'INFO') == 'DEBUG':
|
||||||
import traceback
|
import traceback
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
# 检测是否含有标签
|
# -------------------------- 事件检测 --------------------------
|
||||||
def detect_event(self, samples):
|
def detect_event(self, samples):
|
||||||
self.pack_contain_event = False
|
self.pack_contain_event = False
|
||||||
|
# 第65通道为事件通道
|
||||||
events = np.array(samples[-2])[0].tolist()
|
events = np.array(samples[-2])[0].tolist()
|
||||||
for idx, event in enumerate(events):
|
for idx, event in enumerate(events):
|
||||||
if int(event) in self.events:
|
if int(event) in self.events:
|
||||||
@@ -281,76 +338,54 @@ class zmqServer(threading.Thread):
|
|||||||
self.count_events[new_key] = self.train_latency + 1
|
self.count_events[new_key] = self.train_latency + 1
|
||||||
self.event_inner_idx = idx
|
self.event_inner_idx = idx
|
||||||
self.pack_contain_event = True
|
self.pack_contain_event = True
|
||||||
|
|
||||||
|
# 倒计时并清理过期事件
|
||||||
drop_items = []
|
drop_items = []
|
||||||
for key, value in self.count_events.items():
|
for key, value in self.count_events.items():
|
||||||
value = value - 1
|
value -= 1
|
||||||
if value == 0:
|
if value == 0:
|
||||||
drop_items.append(key)
|
drop_items.append(key)
|
||||||
self.count_events[key] = value
|
self.count_events[key] = value
|
||||||
|
|
||||||
for key in drop_items:
|
for key in drop_items:
|
||||||
del self.count_events[key]
|
del self.count_events[key]
|
||||||
|
|
||||||
if drop_items:
|
if drop_items:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
# -------------------------- 主循环 --------------------------
|
||||||
|
|
||||||
|
|
||||||
def _process_send_queue(self):
|
|
||||||
"""处理发送队列,向所有命令客户端广播消息"""
|
|
||||||
while not self.send_queue.empty():
|
|
||||||
method, params = self.send_queue.get()
|
|
||||||
if self.cmd_clients:
|
|
||||||
try:
|
|
||||||
msg = {'method': method, 'params': params}
|
|
||||||
msg_bytes = json.dumps(msg).encode('utf-8')
|
|
||||||
|
|
||||||
# 打印日志(隐藏大尺寸数据)
|
|
||||||
if method in ['single_trial_plot', 'miReport']:
|
|
||||||
print(f"{{'method': '{method}', 'params': <Base64 Image Data>}}")
|
|
||||||
else:
|
|
||||||
print(f"Sending CMD message: {msg}")
|
|
||||||
|
|
||||||
# 广播到所有命令客户端
|
|
||||||
for client_id in list(self.cmd_clients):
|
|
||||||
try:
|
|
||||||
self.cmd_socket.send_multipart([client_id, b'', msg_bytes])
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error sending to CMD client {client_id}: {e}")
|
|
||||||
self.cmd_clients.discard(client_id) # 移除失效客户端
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error preparing broadcast: {e}")
|
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
self.running = True
|
self.running = True
|
||||||
algo_log(f"algo ZMQ Server started - CMD Port: {self.cmd_port}, DATA Port: {self.data_port}", level="INFO")
|
algo_log(f"ZMQ服务器启动成功 - 命令端口: {self.cmd_port}, 数据端口: {self.data_port}", level="INFO")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while self.running:
|
while self.running:
|
||||||
# 1. 处理发送队列(命令端口广播)
|
# 1. 处理两个端口的发送队列(必须在主线程执行)
|
||||||
self._process_send_queue()
|
self._process_cmd_send_queue()
|
||||||
|
self._process_data_send_queue()
|
||||||
|
|
||||||
# 2. 轮训监听两个Socket的输入事件
|
# 2. 轮询监听两个端口的输入事件
|
||||||
socks = dict(self.poller.poll(50))
|
socks = dict(self.poller.poll(50))
|
||||||
|
|
||||||
# 处理命令端口消息
|
# 处理8099命令端口消息
|
||||||
if self.cmd_socket in socks and socks[self.cmd_socket] == zmq.POLLIN:
|
if self.cmd_socket in socks and socks[self.cmd_socket] == zmq.POLLIN:
|
||||||
frames = self.cmd_socket.recv_multipart()
|
frames = self.cmd_socket.recv_multipart()
|
||||||
self._handle_cmd_message(frames)
|
self._handle_cmd_message(frames)
|
||||||
|
|
||||||
# 处理数据端口消息
|
# 处理8100数据端口消息
|
||||||
if self.data_socket in socks and socks[self.data_socket] == zmq.POLLIN:
|
if self.data_socket in socks and socks[self.data_socket] == zmq.POLLIN:
|
||||||
frames = self.data_socket.recv_multipart()
|
frames = self.data_socket.recv_multipart()
|
||||||
self._handle_data_message(frames)
|
self._handle_data_message(frames)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Server error occurred: {e}")
|
algo_log(f"服务器主循环异常: {e}", level="ERROR")
|
||||||
finally:
|
finally:
|
||||||
self.running = False
|
self.running = False
|
||||||
# 关闭所有Socket和上下文
|
# 优雅关闭所有资源
|
||||||
self.cmd_socket.close()
|
self.cmd_socket.close()
|
||||||
self.data_socket.close()
|
self.data_socket.close()
|
||||||
self.context.term()
|
self.context.term()
|
||||||
print("Server sockets and context closed.")
|
algo_log("ZMQ服务器已关闭", level="INFO")
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
"""显式关闭服务器"""
|
"""显式关闭服务器"""
|
||||||
@@ -358,10 +393,10 @@ class zmqServer(threading.Thread):
|
|||||||
self.cmd_socket.close()
|
self.cmd_socket.close()
|
||||||
self.data_socket.close()
|
self.data_socket.close()
|
||||||
self.context.term()
|
self.context.term()
|
||||||
print(f"Server closed explicitly - CMD Port: {self.cmd_port}, DATA Port: {self.data_port}")
|
algo_log(f"服务器已显式关闭 - 命令端口: {self.cmd_port}, 数据端口: {self.data_port}", level="INFO")
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# 初始化并启动服务器(默认cmd=8099, data=8100)
|
# 初始化并启动服务器
|
||||||
server = zmqServer()
|
server = zmqServer()
|
||||||
server.start()
|
server.start()
|
||||||
|
|
||||||
@@ -370,5 +405,5 @@ if __name__ == '__main__':
|
|||||||
while server.running:
|
while server.running:
|
||||||
threading.Event().wait(1)
|
threading.Event().wait(1)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print("Received KeyboardInterrupt, stopping server...")
|
algo_log("收到键盘中断信号,正在停止服务器...", level="INFO")
|
||||||
server.stop()
|
server.stop()
|
||||||
@@ -42,7 +42,7 @@ if __name__ == "__main__":
|
|||||||
# )
|
# )
|
||||||
|
|
||||||
device_info= get_device_info(1)
|
device_info= get_device_info(1)
|
||||||
algo_log(f"device_info: {device_info}", level="INFO")
|
algo_log(f"device_info: {device_info}", level="DEBUG")
|
||||||
decoder = Decoder_main(device_info=device_info)
|
decoder = Decoder_main(device_info=device_info)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
Reference in New Issue
Block a user