351 lines
12 KiB
Python
351 lines
12 KiB
Python
|
|
# -*- coding: utf-8 -*-
|
|||
|
|
"""
|
|||
|
|
脑电滤波服务 8100端口测试工具【最终修复版】
|
|||
|
|
修复:1. Matplotlib中文字体乱码 2. ZMQ双连接收不到数据问题
|
|||
|
|
通信规范:
|
|||
|
|
上位机 -> 服务端:send_multipart([client_id, b"", data_buf]) 共3帧
|
|||
|
|
服务端 recv_multipart() 帧长度 = 3
|
|||
|
|
时序:每20ms(0.02s)发送一包 (5,66),服务端200ms回传 (50,64)
|
|||
|
|
"""
|
|||
|
|
import sys
|
|||
|
|
import time
|
|||
|
|
import threading
|
|||
|
|
import logging
|
|||
|
|
import traceback
|
|||
|
|
from collections import deque
|
|||
|
|
import numpy as np
|
|||
|
|
import zmq
|
|||
|
|
import matplotlib.pyplot as plt
|
|||
|
|
from matplotlib.animation import FuncAnimation
|
|||
|
|
|
|||
|
|
# ===================== 全局前置:修复Matplotlib中文字体 & 负号显示 =====================
|
|||
|
|
plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "WenQuanYi Micro Hei"]
|
|||
|
|
plt.rcParams["axes.unicode_minus"] = False # 解决负号显示异常
|
|||
|
|
|
|||
|
|
# ===================== 【1. 全局可配置参数区】 =====================
|
|||
|
|
# ZMQ 服务端配置
|
|||
|
|
ZMQ_SERVER_IP = "127.0.0.1"
|
|||
|
|
ZMQ_SERVER_PORT = 8100
|
|||
|
|
ZMQ_SOCKET_TIMEOUT = 3000 # 套接字超时(ms)
|
|||
|
|
POLL_TIMEOUT = 10 # Poll轮询超时(ms),不影响发包时序
|
|||
|
|
|
|||
|
|
# 数据报文配置(严格对齐业务)
|
|||
|
|
PKG_SEND_SHAPE = (5, 66) # 发送包 shape (点数, 总通道)
|
|||
|
|
PKG_RECV_SHAPE = (50, 64) # 滤波回包 shape (点数, 脑电通道)
|
|||
|
|
SEND_INTERVAL = 0.02 # 上位机发包间隔 20ms
|
|||
|
|
SAMPLE_RATE = 250 # 采样率 Hz
|
|||
|
|
|
|||
|
|
# 通道定义
|
|||
|
|
CH_EEG = 64
|
|||
|
|
CH_EVENT = 64
|
|||
|
|
CH_RESERVED = 65
|
|||
|
|
|
|||
|
|
# ZMQ 三帧报文固定字段(和你服务端代码完全一致)
|
|||
|
|
CLIENT_ID = b"test_client_001"
|
|||
|
|
EMPTY_FRAME = b""
|
|||
|
|
|
|||
|
|
# 仿真信号配置(可自由调参测试滤波)
|
|||
|
|
TARGET_CHANNEL = 0
|
|||
|
|
SIGNAL_FREQ_LIST = [10.0, 22.0]
|
|||
|
|
SIGNAL_AMP = 1.8
|
|||
|
|
NOISE_GAUSSIAN_AMP = 0.4
|
|||
|
|
NOISE_POWER50_AMP = 0.3
|
|||
|
|
EVENT_LABEL_VAL = 1
|
|||
|
|
RESERVED_VAL = 0.0
|
|||
|
|
|
|||
|
|
# 可视化配置
|
|||
|
|
MAX_PLOT_POINTS = 800
|
|||
|
|
PLOT_REFRESH_INTERVAL = 80
|
|||
|
|
FFT_N_POINTS = 256
|
|||
|
|
PLOT_X_LIMIT_FREQ = (0, 60)
|
|||
|
|
|
|||
|
|
# 运行控制
|
|||
|
|
MAX_RUN_SECONDS = None
|
|||
|
|
ENABLE_RECONNECT = True
|
|||
|
|
PRINT_STAT_INTERVAL = 5.0
|
|||
|
|
|
|||
|
|
# ===================== 【2. 全局变量 & 线程安全】 =====================
|
|||
|
|
g_running = threading.Event()
|
|||
|
|
g_running.set()
|
|||
|
|
data_lock = threading.Lock()
|
|||
|
|
|
|||
|
|
# 绘图数据缓冲区
|
|||
|
|
raw_data_buf = deque(maxlen=MAX_PLOT_POINTS)
|
|||
|
|
filt_data_buf = deque(maxlen=MAX_PLOT_POINTS)
|
|||
|
|
|
|||
|
|
# 运行统计
|
|||
|
|
stat = {
|
|||
|
|
"send_cnt": 0,
|
|||
|
|
"recv_cnt": 0,
|
|||
|
|
"start_time": time.perf_counter(),
|
|||
|
|
"last_print_time": time.perf_counter()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
# ===================== 【3. 日志配置】 =====================
|
|||
|
|
def init_logger():
|
|||
|
|
log_format = "%(asctime)s | %(levelname)-8s | %(message)s"
|
|||
|
|
logging.basicConfig(
|
|||
|
|
level=logging.INFO,
|
|||
|
|
format=log_format,
|
|||
|
|
datefmt="%Y-%m-%d %H:%M:%S"
|
|||
|
|
)
|
|||
|
|
return logging.getLogger("FilterTest")
|
|||
|
|
|
|||
|
|
logger = init_logger()
|
|||
|
|
|
|||
|
|
# ===================== 【4. 仿真脑电数据生成 (5,66)】 =====================
|
|||
|
|
def generate_eeg_packet(pkt_idx: int) -> np.ndarray:
|
|||
|
|
"""生成单包 (5,66) 仿真数据:脑电+噪声+工频+事件通道+保留通道"""
|
|||
|
|
n_point, n_chan = PKG_SEND_SHAPE
|
|||
|
|
base_t = pkt_idx * n_point / SAMPLE_RATE
|
|||
|
|
t_arr = base_t + np.arange(n_point) / SAMPLE_RATE
|
|||
|
|
|
|||
|
|
data = np.zeros((n_point, n_chan), dtype=np.float64)
|
|||
|
|
|
|||
|
|
# 64路脑电:多频信号 + 50Hz工频 + 高斯白噪声
|
|||
|
|
for ch in range(CH_EEG):
|
|||
|
|
sig = 0.0
|
|||
|
|
for freq in SIGNAL_FREQ_LIST:
|
|||
|
|
sig += SIGNAL_AMP * np.sin(2 * np.pi * freq * t_arr)
|
|||
|
|
sig += NOISE_POWER50_AMP * np.sin(2 * np.pi * 50 * t_arr)
|
|||
|
|
sig += NOISE_GAUSSIAN_AMP * np.random.randn(n_point)
|
|||
|
|
data[:, ch] = sig
|
|||
|
|
|
|||
|
|
# 事件通道、保留通道赋值
|
|||
|
|
data[:, CH_EVENT] = EVENT_LABEL_VAL
|
|||
|
|
data[:, CH_RESERVED] = RESERVED_VAL
|
|||
|
|
return data
|
|||
|
|
|
|||
|
|
# ===================== 【5. 核心修复:单DEALER连接 + Poller 同时收发】 =====================
|
|||
|
|
def zmq_io_thread():
|
|||
|
|
"""
|
|||
|
|
唯一ZMQ工作线程:单个DEALER连接,同时发包+收包(对齐真实上位机)
|
|||
|
|
使用 Poller 多路复用,避免阻塞、超时报错
|
|||
|
|
"""
|
|||
|
|
context = zmq.Context()
|
|||
|
|
pkt_index = 0
|
|||
|
|
send_interval = SEND_INTERVAL
|
|||
|
|
|
|||
|
|
while g_running.is_set():
|
|||
|
|
try:
|
|||
|
|
# 新建 DEALER 套接字(全局唯一连接)
|
|||
|
|
sock = context.socket(zmq.DEALER)
|
|||
|
|
sock.setsockopt(zmq.RCVTIMEO, ZMQ_SOCKET_TIMEOUT)
|
|||
|
|
sock.setsockopt(zmq.SNDTIMEO, ZMQ_SOCKET_TIMEOUT)
|
|||
|
|
sock.connect(f"tcp://{ZMQ_SERVER_IP}:{ZMQ_SERVER_PORT}")
|
|||
|
|
logger.info(f"ZMQ 连接成功 -> {ZMQ_SERVER_IP}:{ZMQ_SERVER_PORT}")
|
|||
|
|
|
|||
|
|
# 注册Poller:监听当前套接字的可读事件
|
|||
|
|
poller = zmq.Poller()
|
|||
|
|
poller.register(sock, zmq.POLLIN)
|
|||
|
|
|
|||
|
|
# 精准发包计时(消除sleep漂移)
|
|||
|
|
next_send_ts = time.perf_counter()
|
|||
|
|
|
|||
|
|
while g_running.is_set():
|
|||
|
|
# 1. 运行时长限制判断
|
|||
|
|
if MAX_RUN_SECONDS is not None:
|
|||
|
|
run_sec = time.perf_counter() - stat["start_time"]
|
|||
|
|
if run_sec > MAX_RUN_SECONDS:
|
|||
|
|
logger.info(f"已到达设定运行时长 {MAX_RUN_SECONDS}s,停止任务")
|
|||
|
|
return
|
|||
|
|
|
|||
|
|
# 2. Poll 轮询:有数据就接收,无数据继续执行发包逻辑
|
|||
|
|
socks_ready = dict(poller.poll(POLL_TIMEOUT))
|
|||
|
|
if sock in socks_ready:
|
|||
|
|
# ========== 接收服务端回包 (multipart) ==========
|
|||
|
|
frames = sock.recv_multipart()
|
|||
|
|
if not frames:
|
|||
|
|
continue
|
|||
|
|
# 取最后一帧为有效滤波数据
|
|||
|
|
recv_bytes = frames[-1]
|
|||
|
|
if not recv_bytes:
|
|||
|
|
continue
|
|||
|
|
|
|||
|
|
# 解析为 (50,64) float64
|
|||
|
|
filt_data = np.frombuffer(recv_bytes, dtype=np.float64)
|
|||
|
|
expect_size = PKG_RECV_SHAPE[0] * PKG_RECV_SHAPE[1]
|
|||
|
|
if filt_data.size != expect_size:
|
|||
|
|
logger.warning(f"回包长度异常:实际{filt_data.size},预期{expect_size}")
|
|||
|
|
continue
|
|||
|
|
filt_data = filt_data.reshape(PKG_RECV_SHAPE)
|
|||
|
|
|
|||
|
|
# 统计 + 写入绘图缓冲区
|
|||
|
|
stat["recv_cnt"] += 1
|
|||
|
|
with data_lock:
|
|||
|
|
filt_data_buf.extend(filt_data[:, TARGET_CHANNEL])
|
|||
|
|
|
|||
|
|
# 定时打印运行状态
|
|||
|
|
now = time.perf_counter()
|
|||
|
|
if now - stat["last_print_time"] > PRINT_STAT_INTERVAL:
|
|||
|
|
run_sec = now - stat["start_time"]
|
|||
|
|
loss_rate = (stat["send_cnt"] - stat["recv_cnt"]) / stat["send_cnt"] * 100 if stat["send_cnt"] > 0 else 0.0
|
|||
|
|
logger.info(
|
|||
|
|
f"运行:{run_sec:.1f}s | 发包:{stat['send_cnt']} | 收包:{stat['recv_cnt']} | 丢包率:{loss_rate:.2f}%"
|
|||
|
|
)
|
|||
|
|
stat["last_print_time"] = now
|
|||
|
|
|
|||
|
|
# 3. 精准定时发包(严格20ms间隔)
|
|||
|
|
current_ts = time.perf_counter()
|
|||
|
|
if current_ts >= next_send_ts:
|
|||
|
|
# 生成 (5,66) 仿真数据包
|
|||
|
|
pkt_data = generate_eeg_packet(pkt_index)
|
|||
|
|
pkt_index += 1
|
|||
|
|
send_buf = pkt_data.tobytes()
|
|||
|
|
|
|||
|
|
# ========== 三帧Multipart发送(和你服务端代码完全一致) ==========
|
|||
|
|
sock.send_multipart([CLIENT_ID, EMPTY_FRAME, send_buf])
|
|||
|
|
|
|||
|
|
# 统计 + 写入原始数据缓冲区
|
|||
|
|
stat["send_cnt"] += 1
|
|||
|
|
with data_lock:
|
|||
|
|
raw_data_buf.extend(pkt_data[:, TARGET_CHANNEL])
|
|||
|
|
|
|||
|
|
# 更新下一次发包时间戳
|
|||
|
|
next_send_ts += send_interval
|
|||
|
|
|
|||
|
|
except zmq.ZMQError as e:
|
|||
|
|
# 区分正常超时 和 网络异常
|
|||
|
|
if e.errno == zmq.EAGAIN:
|
|||
|
|
continue
|
|||
|
|
logger.warning(f"ZMQ 连接异常: {e}")
|
|||
|
|
sock.close()
|
|||
|
|
poller.unregister(sock)
|
|||
|
|
if not ENABLE_RECONNECT:
|
|||
|
|
break
|
|||
|
|
logger.info("500ms 后尝试重连...")
|
|||
|
|
time.sleep(0.5)
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.error(f"IO线程未知异常:\n{traceback.format_exc()}")
|
|||
|
|
break
|
|||
|
|
|
|||
|
|
context.term()
|
|||
|
|
logger.info("ZMQ IO 线程已退出")
|
|||
|
|
|
|||
|
|
# ===================== 【6. 可视化绘图(无逻辑改动,已前置修复字体)】 =====================
|
|||
|
|
def init_plot():
|
|||
|
|
fig = plt.figure(figsize=(14, 9))
|
|||
|
|
fig.suptitle(f"脑电滤波测试 | 观测通道: {TARGET_CHANNEL}", fontsize=14)
|
|||
|
|
|
|||
|
|
ax1 = plt.subplot(2, 2, 1)
|
|||
|
|
ax1.set_title("原始输入波形 (含噪声+工频)")
|
|||
|
|
ax1.set_ylabel("幅值")
|
|||
|
|
ax1.grid(True, alpha=0.3)
|
|||
|
|
line_raw, = ax1.plot([], [], color="#1f77b4", linewidth=1)
|
|||
|
|
|
|||
|
|
ax2 = plt.subplot(2, 2, 2)
|
|||
|
|
ax2.set_title("滤波后输出波形")
|
|||
|
|
ax2.set_ylabel("幅值")
|
|||
|
|
ax2.grid(True, alpha=0.3)
|
|||
|
|
line_filt, = ax2.plot([], [], color="#d62728", linewidth=1)
|
|||
|
|
|
|||
|
|
ax3 = plt.subplot(2, 2, 3)
|
|||
|
|
ax3.set_title("原始信号频谱")
|
|||
|
|
ax3.set_xlabel("频率 (Hz)")
|
|||
|
|
ax3.set_xlim(*PLOT_X_LIMIT_FREQ)
|
|||
|
|
ax3.grid(True, alpha=0.3)
|
|||
|
|
line_raw_fft, = ax3.plot([], [], color="#1f77b4")
|
|||
|
|
|
|||
|
|
ax4 = plt.subplot(2, 2, 4)
|
|||
|
|
ax4.set_title("滤波后信号频谱")
|
|||
|
|
ax4.set_xlabel("频率 (Hz)")
|
|||
|
|
ax4.set_xlim(*PLOT_X_LIMIT_FREQ)
|
|||
|
|
ax4.grid(True, alpha=0.3)
|
|||
|
|
line_filt_fft, = ax4.plot([], [], color="#d62728")
|
|||
|
|
|
|||
|
|
plt.tight_layout(rect=[0, 0, 1, 0.96])
|
|||
|
|
return fig, [line_raw, line_filt, line_raw_fft, line_filt_fft], [ax1, ax2, ax3, ax4]
|
|||
|
|
|
|||
|
|
def update_plot(frame, lines, axes):
|
|||
|
|
line_raw, line_filt, line_raw_fft, line_filt_fft = lines
|
|||
|
|
ax1, ax2, ax3, ax4 = axes
|
|||
|
|
|
|||
|
|
with data_lock:
|
|||
|
|
raw_data = list(raw_data_buf)
|
|||
|
|
filt_data = list(filt_data_buf)
|
|||
|
|
|
|||
|
|
# 时域波形
|
|||
|
|
if raw_data:
|
|||
|
|
x_raw = np.arange(len(raw_data))
|
|||
|
|
line_raw.set_data(x_raw, raw_data)
|
|||
|
|
ax1.relim()
|
|||
|
|
ax1.autoscale_view()
|
|||
|
|
|
|||
|
|
if filt_data:
|
|||
|
|
x_filt = np.arange(len(filt_data))
|
|||
|
|
line_filt.set_data(x_filt, filt_data)
|
|||
|
|
ax2.relim()
|
|||
|
|
ax2.autoscale_view()
|
|||
|
|
|
|||
|
|
# 频谱计算(汉宁窗减少频谱泄露)
|
|||
|
|
def calc_fft(sig, n_fft):
|
|||
|
|
if len(sig) < n_fft:
|
|||
|
|
return [], []
|
|||
|
|
win = np.hanning(n_fft)
|
|||
|
|
sig_win = sig[-n_fft:] * win
|
|||
|
|
fft_vals = np.fft.fft(sig_win)
|
|||
|
|
fft_amp = np.abs(fft_vals)[:n_fft//2]
|
|||
|
|
freq = np.fft.fftfreq(n_fft, 1/SAMPLE_RATE)[:n_fft//2]
|
|||
|
|
return freq, fft_amp
|
|||
|
|
|
|||
|
|
freq_raw, amp_raw = calc_fft(raw_data, FFT_N_POINTS)
|
|||
|
|
freq_filt, amp_filt = calc_fft(filt_data, FFT_N_POINTS)
|
|||
|
|
|
|||
|
|
line_raw_fft.set_data(freq_raw, amp_raw)
|
|||
|
|
line_filt_fft.set_data(freq_filt, amp_filt)
|
|||
|
|
ax3.relim()
|
|||
|
|
ax3.autoscale_view(scaley=True)
|
|||
|
|
ax4.relim()
|
|||
|
|
ax4.autoscale_view(scaley=True)
|
|||
|
|
|
|||
|
|
return lines
|
|||
|
|
|
|||
|
|
# ===================== 【7. 资源释放 & 主入口】 =====================
|
|||
|
|
def clean_resource():
|
|||
|
|
g_running.clear()
|
|||
|
|
logger.info("开始停止所有线程...")
|
|||
|
|
time.sleep(0.3)
|
|||
|
|
plt.close("all")
|
|||
|
|
logger.info("资源释放完成")
|
|||
|
|
|
|||
|
|
def main():
|
|||
|
|
logger.info("=" * 60)
|
|||
|
|
logger.info("脑电滤波测试客户端 【修复版】启动")
|
|||
|
|
logger.info(f"服务端地址: {ZMQ_SERVER_IP}:{ZMQ_SERVER_PORT}")
|
|||
|
|
logger.info(f"发包格式: {PKG_SEND_SHAPE} | 间隔: {SEND_INTERVAL*1000:.0f}ms")
|
|||
|
|
logger.info(f"回包格式: {PKG_RECV_SHAPE} | ZMQ三帧报文 [客户端ID, 空帧, 数据帧]")
|
|||
|
|
logger.info("=" * 60)
|
|||
|
|
|
|||
|
|
# 启动唯一ZMQ收发线程
|
|||
|
|
io_thread = threading.Thread(target=zmq_io_thread, daemon=True, name="ZMQ_IO_Thread")
|
|||
|
|
io_thread.start()
|
|||
|
|
|
|||
|
|
# 启动可视化绘图
|
|||
|
|
fig, lines, axes = init_plot()
|
|||
|
|
ani = FuncAnimation(
|
|||
|
|
fig, update_plot,
|
|||
|
|
fargs=(lines, axes),
|
|||
|
|
interval=PLOT_REFRESH_INTERVAL,
|
|||
|
|
blit=True,
|
|||
|
|
cache_frame_data=False
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 主线程阻塞,监听关闭
|
|||
|
|
try:
|
|||
|
|
plt.show()
|
|||
|
|
except KeyboardInterrupt:
|
|||
|
|
logger.info("收到 Ctrl+C 中断信号,准备退出")
|
|||
|
|
finally:
|
|||
|
|
# 输出最终统计
|
|||
|
|
run_total = time.perf_counter() - stat["start_time"]
|
|||
|
|
loss_rate = (stat["send_cnt"] - stat["recv_cnt"]) / stat["send_cnt"] * 100 if stat["send_cnt"] > 0 else 0.0
|
|||
|
|
logger.info(f"\n===== 运行汇总 =====")
|
|||
|
|
logger.info(f"总运行时长: {run_total:.1f} s")
|
|||
|
|
logger.info(f"总发包数: {stat['send_cnt']}")
|
|||
|
|
logger.info(f"总收包数: {stat['recv_cnt']}")
|
|||
|
|
logger.info(f"整体丢包率: {loss_rate:.2f} %")
|
|||
|
|
clean_resource()
|
|||
|
|
sys.exit(0)
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
|
|||
|
|
main()
|