Files
bci_algo/filter_test.py

421 lines
16 KiB
Python
Raw Normal View History

2026-06-09 16:46:07 +08:00
# -*- coding: utf-8 -*-
"""
2026-06-09 18:30:56 +08:00
脑电滤波服务 8100端口测试工具统计逻辑专项优化版
优化点
1. 5秒预热(250个发包)预热结束后才启动丢包/数据统计
2. 业务比例0.02s发1包200ms收1包 10 个发包对应 1 个回包
3. 通道校验发送(5,66) 仅对比前64通道接收(50,64)全通道比对
4. 区分全局总包数 / 有效统计区间包数理论收包数实际收包数丢包数丢包率
5. 新增64通道整体数据均值/极值比对校验数据有效性
通信规范send_multipart([client_id, b"", data_buf]) 三帧报文服务端 recv_multipart 长度=3
2026-06-09 16:46:07 +08:00
"""
import sys
import time
import threading
import logging
import traceback
from collections import deque
import numpy as np
import zmq
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
# ===================== 全局前置修复Matplotlib中文字体 & 负号显示 =====================
plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "WenQuanYi Micro Hei"]
2026-06-09 18:30:56 +08:00
plt.rcParams["axes.unicode_minus"] = False
2026-06-09 16:46:07 +08:00
2026-06-09 18:30:56 +08:00
# ===================== 【1. 全局业务固定参数(核心统计规则)】 =====================
2026-06-09 16:46:07 +08:00
# ZMQ 服务端配置
2026-06-10 08:24:20 +08:00
ZMQ_SERVER_IP = "127.0.0.1"
2026-06-09 16:46:07 +08:00
ZMQ_SERVER_PORT = 8100
ZMQ_SOCKET_TIMEOUT = 3000 # 套接字超时(ms)
2026-06-09 18:30:56 +08:00
POLL_TIMEOUT = 10 # Poll轮询超时(ms)
# 时序 & 统计核心规则(严格对齐现场业务)
SEND_INTERVAL = 0.02 # 上位机发包间隔20ms/包
RECV_INTERVAL = 0.2 # 服务端回包间隔200ms/包
PREHEAT_SECONDS = 5.0 # 滤波缓存预热时长5秒
# 计算:预热需要的发包总数 = 预热时长 / 单包发送间隔
PREHEAT_SEND_PACKS = int(PREHEAT_SECONDS / SEND_INTERVAL) # 5 / 0.02 = 250 包
# 收发包比例每多少个发包对应1个回包
PACK_RATIO = int(RECV_INTERVAL / SEND_INTERVAL) # 0.2 / 0.02 = 10
# 数据报文形状
PKG_SEND_SHAPE = (5, 66) # 发送包 (点数, 总通道)
PKG_RECV_SHAPE = (50, 64) # 回包 (点数, 有效脑电通道)
SAMPLE_RATE = 250
# 通道定义对比仅使用前64路脑电通道
CH_EEG_VALID = 64 # 共同对比通道数0~63
2026-06-09 16:46:07 +08:00
CH_EVENT = 64
CH_RESERVED = 65
2026-06-09 18:30:56 +08:00
# ZMQ 三帧报文固定字段
2026-06-09 16:46:07 +08:00
CLIENT_ID = b"test_client_001"
EMPTY_FRAME = b""
2026-06-09 18:30:56 +08:00
# 仿真信号配置
2026-06-09 16:46:07 +08:00
TARGET_CHANNEL = 0
2026-06-10 08:24:20 +08:00
SIGNAL_FREQ_LIST = [3, 13]
2026-06-09 16:46:07 +08:00
SIGNAL_AMP = 1.8
NOISE_GAUSSIAN_AMP = 0.4
NOISE_POWER50_AMP = 0.3
EVENT_LABEL_VAL = 1
RESERVED_VAL = 0.0
# 可视化配置
MAX_PLOT_POINTS = 800
PLOT_REFRESH_INTERVAL = 80
FFT_N_POINTS = 256
PLOT_X_LIMIT_FREQ = (0, 60)
# 运行控制
MAX_RUN_SECONDS = None
ENABLE_RECONNECT = True
PRINT_STAT_INTERVAL = 5.0
2026-06-09 18:30:56 +08:00
# ===================== 【2. 全局变量 + 统计结构体(重构统计逻辑)】 =====================
2026-06-09 16:46:07 +08:00
g_running = threading.Event()
g_running.set()
data_lock = threading.Lock()
2026-06-09 18:30:56 +08:00
# 绘图缓冲区
2026-06-09 16:46:07 +08:00
raw_data_buf = deque(maxlen=MAX_PLOT_POINTS)
filt_data_buf = deque(maxlen=MAX_PLOT_POINTS)
2026-06-09 18:30:56 +08:00
# ===================== 全新统计变量(区分预热/正式统计) =====================
2026-06-09 16:46:07 +08:00
stat = {
2026-06-09 18:30:56 +08:00
# 全局总包数(包含预热包)
"total_send": 0,
"total_recv": 0,
# 有效统计区间预热250包之后
"valid_send": 0, # 有效发包数
"valid_recv": 0, # 有效收包数
"theo_recv": 0, # 理论应收到包数 = valid_send // PACK_RATIO
# 运行时间
2026-06-09 16:46:07 +08:00
"start_time": time.perf_counter(),
2026-06-09 18:30:56 +08:00
"last_print_time": time.perf_counter(),
# 数据校验缓存保存最新一包原始64通道数据用于和回包比对
"latest_raw_64ch": None
2026-06-09 16:46:07 +08:00
}
# ===================== 【3. 日志配置】 =====================
def init_logger():
log_format = "%(asctime)s | %(levelname)-8s | %(message)s"
logging.basicConfig(
level=logging.INFO,
format=log_format,
datefmt="%Y-%m-%d %H:%M:%S"
)
return logging.getLogger("FilterTest")
logger = init_logger()
# ===================== 【4. 仿真脑电数据生成 (5,66)】 =====================
def generate_eeg_packet(pkt_idx: int) -> np.ndarray:
2026-06-09 18:30:56 +08:00
"""生成单包 (5,66) 仿真数据"""
2026-06-09 16:46:07 +08:00
n_point, n_chan = PKG_SEND_SHAPE
base_t = pkt_idx * n_point / SAMPLE_RATE
t_arr = base_t + np.arange(n_point) / SAMPLE_RATE
data = np.zeros((n_point, n_chan), dtype=np.float64)
2026-06-09 18:30:56 +08:00
# 64路脑电信号
for ch in range(CH_EEG_VALID):
2026-06-09 16:46:07 +08:00
sig = 0.0
for freq in SIGNAL_FREQ_LIST:
sig += SIGNAL_AMP * np.sin(2 * np.pi * freq * t_arr)
sig += NOISE_POWER50_AMP * np.sin(2 * np.pi * 50 * t_arr)
sig += NOISE_GAUSSIAN_AMP * np.random.randn(n_point)
data[:, ch] = sig
2026-06-09 18:30:56 +08:00
# 事件通道、保留通道
2026-06-09 16:46:07 +08:00
data[:, CH_EVENT] = EVENT_LABEL_VAL
data[:, CH_RESERVED] = RESERVED_VAL
return data
2026-06-09 18:30:56 +08:00
# ===================== 【5. ZMQ 核心IO线程单连接+Poller保留原有通信逻辑】 =====================
2026-06-09 16:46:07 +08:00
def zmq_io_thread():
context = zmq.Context()
pkt_index = 0
send_interval = SEND_INTERVAL
2026-06-09 18:30:56 +08:00
logger.info(f"滤波预热配置:{PREHEAT_SECONDS}秒 / {PREHEAT_SEND_PACKS} 个发包后开始统计")
logger.info(f"收发比例:每 {PACK_RATIO} 个发包 → 1 个滤波回包")
2026-06-09 16:46:07 +08:00
while g_running.is_set():
try:
sock = context.socket(zmq.DEALER)
sock.setsockopt(zmq.RCVTIMEO, ZMQ_SOCKET_TIMEOUT)
sock.setsockopt(zmq.SNDTIMEO, ZMQ_SOCKET_TIMEOUT)
sock.connect(f"tcp://{ZMQ_SERVER_IP}:{ZMQ_SERVER_PORT}")
logger.info(f"ZMQ 连接成功 -> {ZMQ_SERVER_IP}:{ZMQ_SERVER_PORT}")
poller = zmq.Poller()
poller.register(sock, zmq.POLLIN)
next_send_ts = time.perf_counter()
while g_running.is_set():
2026-06-09 18:30:56 +08:00
# 全局运行时长限制
2026-06-09 16:46:07 +08:00
if MAX_RUN_SECONDS is not None:
run_sec = time.perf_counter() - stat["start_time"]
if run_sec > MAX_RUN_SECONDS:
logger.info(f"已到达设定运行时长 {MAX_RUN_SECONDS}s停止任务")
return
2026-06-09 18:30:56 +08:00
# ========== 1. 轮询接收服务端回包 ==========
2026-06-09 16:46:07 +08:00
socks_ready = dict(poller.poll(POLL_TIMEOUT))
if sock in socks_ready:
frames = sock.recv_multipart()
if not frames:
continue
recv_bytes = frames[-1]
if not recv_bytes:
continue
2026-06-09 18:30:56 +08:00
# 解析回包 (50,64)
2026-06-09 16:46:07 +08:00
filt_data = np.frombuffer(recv_bytes, dtype=np.float64)
expect_size = PKG_RECV_SHAPE[0] * PKG_RECV_SHAPE[1]
if filt_data.size != expect_size:
logger.warning(f"回包长度异常:实际{filt_data.size},预期{expect_size}")
continue
filt_data = filt_data.reshape(PKG_RECV_SHAPE)
2026-06-09 18:30:56 +08:00
# 全局收包计数
stat["total_recv"] += 1
# 仅预热完成后,计入有效统计收包
if stat["total_send"] > PREHEAT_SEND_PACKS:
stat["valid_recv"] += 1
# 写入绘图缓冲区
2026-06-09 16:46:07 +08:00
with data_lock:
filt_data_buf.extend(filt_data[:, TARGET_CHANNEL])
2026-06-09 18:30:56 +08:00
# ---------- 新增64通道数据比对发包前64通道 <-> 回包64通道 ----------
raw_64ch = stat["latest_raw_64ch"]
if raw_64ch is not None:
raw_mean = np.mean(raw_64ch)
filt_mean = np.mean(filt_data)
raw_amp = np.max(np.abs(raw_64ch))
filt_amp = np.max(np.abs(filt_data))
logger.debug(
f"【通道数据比对】原始64通道均值:{raw_mean:.4f} 幅值:{raw_amp:.4f} | "
f"滤波后均值:{filt_mean:.4f} 幅值:{filt_amp:.4f}"
2026-06-09 16:46:07 +08:00
)
2026-06-09 18:30:56 +08:00
# ========== 2. 精准定时发送数据包 ==========
2026-06-09 16:46:07 +08:00
current_ts = time.perf_counter()
if current_ts >= next_send_ts:
2026-06-09 18:30:56 +08:00
# 生成(5,66)仿真包
2026-06-09 16:46:07 +08:00
pkt_data = generate_eeg_packet(pkt_index)
pkt_index += 1
send_buf = pkt_data.tobytes()
2026-06-09 18:30:56 +08:00
# 标准三帧Multipart发送
2026-06-09 16:46:07 +08:00
sock.send_multipart([CLIENT_ID, EMPTY_FRAME, send_buf])
2026-06-09 18:30:56 +08:00
# ---------- 发包计数逻辑(核心优化:预热区分) ----------
stat["total_send"] += 1
# 预热完成后,计入有效发包
if stat["total_send"] > PREHEAT_SEND_PACKS:
stat["valid_send"] += 1
# 计算理论应收包数
stat["theo_recv"] = stat["valid_send"] // PACK_RATIO
# 缓存当前包前64通道用于后续数据比对
stat["latest_raw_64ch"] = pkt_data[:, :CH_EEG_VALID]
# 绘图缓冲区(单通道波形)
2026-06-09 16:46:07 +08:00
with data_lock:
raw_data_buf.extend(pkt_data[:, TARGET_CHANNEL])
2026-06-09 18:30:56 +08:00
# 更新下一次发包时间
2026-06-09 16:46:07 +08:00
next_send_ts += send_interval
2026-06-09 18:30:56 +08:00
# ========== 3. 定时打印统计信息(区分预热/正式统计) ==========
now = time.perf_counter()
if now - stat["last_print_time"] > PRINT_STAT_INTERVAL:
run_sec = now - stat["start_time"]
total_send = stat["total_send"]
total_recv = stat["total_recv"]
# 分支1仍在预热阶段
if total_send <= PREHEAT_SEND_PACKS:
remain = PREHEAT_SEND_PACKS - total_send
logger.info(
f"[预热中] 运行:{run_sec:.1f}s | 已发包:{total_send}/{PREHEAT_SEND_PACKS} | "
f"剩余预热包:{remain} | 暂不统计丢包"
)
# 分支2预热完成进入正式统计
else:
v_send = stat["valid_send"]
v_recv = stat["valid_recv"]
t_recv = stat["theo_recv"]
loss_cnt = t_recv - v_recv
loss_rate = (loss_cnt / t_recv * 100) if t_recv > 0 else 0.0
logger.info(
f"[正式统计] 运行:{run_sec:.1f}s | "
f"全局总包: 发{total_send}/收{total_recv} | "
f"有效区间: 发{v_send}/应收{t_recv}/实收{v_recv} | "
f"丢包数:{loss_cnt} | 丢包率:{loss_rate:.2f}%"
)
stat["last_print_time"] = now
2026-06-09 16:46:07 +08:00
except zmq.ZMQError as e:
if e.errno == zmq.EAGAIN:
continue
logger.warning(f"ZMQ 连接异常: {e}")
sock.close()
poller.unregister(sock)
if not ENABLE_RECONNECT:
break
logger.info("500ms 后尝试重连...")
time.sleep(0.5)
except Exception as e:
logger.error(f"IO线程未知异常:\n{traceback.format_exc()}")
break
context.term()
logger.info("ZMQ IO 线程已退出")
2026-06-09 18:30:56 +08:00
# ===================== 【6. 可视化绘图(无改动)】 =====================
2026-06-09 16:46:07 +08:00
def init_plot():
fig = plt.figure(figsize=(14, 9))
fig.suptitle(f"脑电滤波测试 | 观测通道: {TARGET_CHANNEL}", fontsize=14)
ax1 = plt.subplot(2, 2, 1)
ax1.set_title("原始输入波形 (含噪声+工频)")
ax1.set_ylabel("幅值")
ax1.grid(True, alpha=0.3)
line_raw, = ax1.plot([], [], color="#1f77b4", linewidth=1)
ax2 = plt.subplot(2, 2, 2)
ax2.set_title("滤波后输出波形")
ax2.set_ylabel("幅值")
ax2.grid(True, alpha=0.3)
line_filt, = ax2.plot([], [], color="#d62728", linewidth=1)
ax3 = plt.subplot(2, 2, 3)
ax3.set_title("原始信号频谱")
ax3.set_xlabel("频率 (Hz)")
ax3.set_xlim(*PLOT_X_LIMIT_FREQ)
ax3.grid(True, alpha=0.3)
line_raw_fft, = ax3.plot([], [], color="#1f77b4")
ax4 = plt.subplot(2, 2, 4)
ax4.set_title("滤波后信号频谱")
ax4.set_xlabel("频率 (Hz)")
ax4.set_xlim(*PLOT_X_LIMIT_FREQ)
ax4.grid(True, alpha=0.3)
line_filt_fft, = ax4.plot([], [], color="#d62728")
plt.tight_layout(rect=[0, 0, 1, 0.96])
return fig, [line_raw, line_filt, line_raw_fft, line_filt_fft], [ax1, ax2, ax3, ax4]
def update_plot(frame, lines, axes):
line_raw, line_filt, line_raw_fft, line_filt_fft = lines
ax1, ax2, ax3, ax4 = axes
with data_lock:
raw_data = list(raw_data_buf)
filt_data = list(filt_data_buf)
if raw_data:
x_raw = np.arange(len(raw_data))
line_raw.set_data(x_raw, raw_data)
ax1.relim()
ax1.autoscale_view()
if filt_data:
x_filt = np.arange(len(filt_data))
line_filt.set_data(x_filt, filt_data)
ax2.relim()
ax2.autoscale_view()
def calc_fft(sig, n_fft):
if len(sig) < n_fft:
return [], []
win = np.hanning(n_fft)
sig_win = sig[-n_fft:] * win
fft_vals = np.fft.fft(sig_win)
fft_amp = np.abs(fft_vals)[:n_fft//2]
freq = np.fft.fftfreq(n_fft, 1/SAMPLE_RATE)[:n_fft//2]
return freq, fft_amp
freq_raw, amp_raw = calc_fft(raw_data, FFT_N_POINTS)
freq_filt, amp_filt = calc_fft(filt_data, FFT_N_POINTS)
line_raw_fft.set_data(freq_raw, amp_raw)
line_filt_fft.set_data(freq_filt, amp_filt)
ax3.relim()
ax3.autoscale_view(scaley=True)
ax4.relim()
ax4.autoscale_view(scaley=True)
return lines
2026-06-09 18:30:56 +08:00
# ===================== 【7. 资源释放 & 最终汇总统计】 =====================
2026-06-09 16:46:07 +08:00
def clean_resource():
g_running.clear()
logger.info("开始停止所有线程...")
time.sleep(0.3)
plt.close("all")
logger.info("资源释放完成")
def main():
2026-06-09 18:30:56 +08:00
logger.info("=" * 70)
logger.info("脑电滤波测试客户端【统计逻辑优化版】启动")
2026-06-09 16:46:07 +08:00
logger.info(f"服务端地址: {ZMQ_SERVER_IP}:{ZMQ_SERVER_PORT}")
2026-06-09 18:30:56 +08:00
logger.info(f"发包: {PKG_SEND_SHAPE}({SEND_INTERVAL*1000:.0f}ms) | 回包: {PKG_RECV_SHAPE}({RECV_INTERVAL*1000:.0f}ms)")
logger.info(f"预热规则: {PREHEAT_SECONDS}秒 / {PREHEAT_SEND_PACKS} 包后开启统计")
logger.info(f"收发比例: 每 {PACK_RATIO} 个发包对应 1 个回包")
logger.info("=" * 70)
2026-06-09 16:46:07 +08:00
2026-06-09 18:30:56 +08:00
# 启动ZMQ收发线程
2026-06-09 16:46:07 +08:00
io_thread = threading.Thread(target=zmq_io_thread, daemon=True, name="ZMQ_IO_Thread")
io_thread.start()
2026-06-09 18:30:56 +08:00
# 启动可视化
2026-06-09 16:46:07 +08:00
fig, lines, axes = init_plot()
ani = FuncAnimation(
fig, update_plot,
fargs=(lines, axes),
interval=PLOT_REFRESH_INTERVAL,
blit=True,
cache_frame_data=False
)
try:
plt.show()
except KeyboardInterrupt:
logger.info("收到 Ctrl+C 中断信号,准备退出")
finally:
2026-06-09 18:30:56 +08:00
# 输出最终完整汇总报表
2026-06-09 16:46:07 +08:00
run_total = time.perf_counter() - stat["start_time"]
2026-06-09 18:30:56 +08:00
total_send = stat["total_send"]
total_recv = stat["total_recv"]
v_send = stat["valid_send"]
v_recv = stat["valid_recv"]
t_recv = stat["theo_recv"]
loss_cnt = t_recv - v_recv
loss_rate = (loss_cnt / t_recv * 100) if t_recv > 0 else 0.0
logger.info(f"\n{'='*50} 最终运行汇总 {'='*50}")
2026-06-09 16:46:07 +08:00
logger.info(f"总运行时长: {run_total:.1f} s")
2026-06-09 18:30:56 +08:00
logger.info(f"【全局总包数】发送: {total_send} | 接收: {total_recv}")
logger.info(f"【有效统计区间(跳过预热{PREHEAT_SEND_PACKS}包)】")
logger.info(f" 有效发包: {v_send} | 理论应收包: {t_recv} | 实际收包: {v_recv}")
logger.info(f" 总丢包数: {loss_cnt} | 整体丢包率: {loss_rate:.2f} %")
logger.info(f"{'='*106}")
2026-06-09 16:46:07 +08:00
clean_resource()
sys.exit(0)
if __name__ == "__main__":
main()