Files
bci_algo/filter_test.py
2026-06-09 18:30:56 +08:00

421 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""
脑电滤波服务 8100端口测试工具【统计逻辑专项优化版】
优化点:
1. 5秒预热(250个发包),预热结束后才启动丢包/数据统计
2. 业务比例0.02s发1包200ms收1包 → 每 10 个发包对应 1 个回包
3. 通道校验:发送(5,66) 仅对比前64通道接收(50,64)全通道比对
4. 区分:全局总包数 / 有效统计区间包数、理论收包数、实际收包数、丢包数、丢包率
5. 新增64通道整体数据均值/极值比对,校验数据有效性
通信规范send_multipart([client_id, b"", data_buf]) 三帧报文,服务端 recv_multipart 长度=3
"""
import sys
import time
import threading
import logging
import traceback
from collections import deque
import numpy as np
import zmq
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
# ===================== 全局前置修复Matplotlib中文字体 & 负号显示 =====================
plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "WenQuanYi Micro Hei"]
plt.rcParams["axes.unicode_minus"] = False
# ===================== 【1. 全局业务固定参数(核心统计规则)】 =====================
# ZMQ 服务端配置
ZMQ_SERVER_IP = "192.168.254.102"
ZMQ_SERVER_PORT = 8100
ZMQ_SOCKET_TIMEOUT = 3000 # 套接字超时(ms)
POLL_TIMEOUT = 10 # Poll轮询超时(ms)
# 时序 & 统计核心规则(严格对齐现场业务)
SEND_INTERVAL = 0.02 # 上位机发包间隔20ms/包
RECV_INTERVAL = 0.2 # 服务端回包间隔200ms/包
PREHEAT_SECONDS = 5.0 # 滤波缓存预热时长5秒
# 计算:预热需要的发包总数 = 预热时长 / 单包发送间隔
PREHEAT_SEND_PACKS = int(PREHEAT_SECONDS / SEND_INTERVAL) # 5 / 0.02 = 250 包
# 收发包比例每多少个发包对应1个回包
PACK_RATIO = int(RECV_INTERVAL / SEND_INTERVAL) # 0.2 / 0.02 = 10
# 数据报文形状
PKG_SEND_SHAPE = (5, 66) # 发送包 (点数, 总通道)
PKG_RECV_SHAPE = (50, 64) # 回包 (点数, 有效脑电通道)
SAMPLE_RATE = 250
# 通道定义对比仅使用前64路脑电通道
CH_EEG_VALID = 64 # 共同对比通道数0~63
CH_EVENT = 64
CH_RESERVED = 65
# ZMQ 三帧报文固定字段
CLIENT_ID = b"test_client_001"
EMPTY_FRAME = b""
# 仿真信号配置
TARGET_CHANNEL = 0
SIGNAL_FREQ_LIST = [3, 10, 36]
SIGNAL_AMP = 1.8
NOISE_GAUSSIAN_AMP = 0.4
NOISE_POWER50_AMP = 0.3
EVENT_LABEL_VAL = 1
RESERVED_VAL = 0.0
# 可视化配置
MAX_PLOT_POINTS = 800
PLOT_REFRESH_INTERVAL = 80
FFT_N_POINTS = 256
PLOT_X_LIMIT_FREQ = (0, 60)
# 运行控制
MAX_RUN_SECONDS = None
ENABLE_RECONNECT = True
PRINT_STAT_INTERVAL = 5.0
# ===================== 【2. 全局变量 + 统计结构体(重构统计逻辑)】 =====================
g_running = threading.Event()
g_running.set()
data_lock = threading.Lock()
# 绘图缓冲区
raw_data_buf = deque(maxlen=MAX_PLOT_POINTS)
filt_data_buf = deque(maxlen=MAX_PLOT_POINTS)
# ===================== 全新统计变量(区分预热/正式统计) =====================
stat = {
# 全局总包数(包含预热包)
"total_send": 0,
"total_recv": 0,
# 有效统计区间预热250包之后
"valid_send": 0, # 有效发包数
"valid_recv": 0, # 有效收包数
"theo_recv": 0, # 理论应收到包数 = valid_send // PACK_RATIO
# 运行时间
"start_time": time.perf_counter(),
"last_print_time": time.perf_counter(),
# 数据校验缓存保存最新一包原始64通道数据用于和回包比对
"latest_raw_64ch": None
}
# ===================== 【3. 日志配置】 =====================
def init_logger():
log_format = "%(asctime)s | %(levelname)-8s | %(message)s"
logging.basicConfig(
level=logging.INFO,
format=log_format,
datefmt="%Y-%m-%d %H:%M:%S"
)
return logging.getLogger("FilterTest")
logger = init_logger()
# ===================== 【4. 仿真脑电数据生成 (5,66)】 =====================
def generate_eeg_packet(pkt_idx: int) -> np.ndarray:
"""生成单包 (5,66) 仿真数据"""
n_point, n_chan = PKG_SEND_SHAPE
base_t = pkt_idx * n_point / SAMPLE_RATE
t_arr = base_t + np.arange(n_point) / SAMPLE_RATE
data = np.zeros((n_point, n_chan), dtype=np.float64)
# 64路脑电信号
for ch in range(CH_EEG_VALID):
sig = 0.0
for freq in SIGNAL_FREQ_LIST:
sig += SIGNAL_AMP * np.sin(2 * np.pi * freq * t_arr)
sig += NOISE_POWER50_AMP * np.sin(2 * np.pi * 50 * t_arr)
sig += NOISE_GAUSSIAN_AMP * np.random.randn(n_point)
data[:, ch] = sig
# 事件通道、保留通道
data[:, CH_EVENT] = EVENT_LABEL_VAL
data[:, CH_RESERVED] = RESERVED_VAL
return data
# ===================== 【5. ZMQ 核心IO线程单连接+Poller保留原有通信逻辑】 =====================
def zmq_io_thread():
context = zmq.Context()
pkt_index = 0
send_interval = SEND_INTERVAL
logger.info(f"滤波预热配置:{PREHEAT_SECONDS}秒 / {PREHEAT_SEND_PACKS} 个发包后开始统计")
logger.info(f"收发比例:每 {PACK_RATIO} 个发包 → 1 个滤波回包")
while g_running.is_set():
try:
sock = context.socket(zmq.DEALER)
sock.setsockopt(zmq.RCVTIMEO, ZMQ_SOCKET_TIMEOUT)
sock.setsockopt(zmq.SNDTIMEO, ZMQ_SOCKET_TIMEOUT)
sock.connect(f"tcp://{ZMQ_SERVER_IP}:{ZMQ_SERVER_PORT}")
logger.info(f"ZMQ 连接成功 -> {ZMQ_SERVER_IP}:{ZMQ_SERVER_PORT}")
poller = zmq.Poller()
poller.register(sock, zmq.POLLIN)
next_send_ts = time.perf_counter()
while g_running.is_set():
# 全局运行时长限制
if MAX_RUN_SECONDS is not None:
run_sec = time.perf_counter() - stat["start_time"]
if run_sec > MAX_RUN_SECONDS:
logger.info(f"已到达设定运行时长 {MAX_RUN_SECONDS}s停止任务")
return
# ========== 1. 轮询接收服务端回包 ==========
socks_ready = dict(poller.poll(POLL_TIMEOUT))
if sock in socks_ready:
frames = sock.recv_multipart()
if not frames:
continue
recv_bytes = frames[-1]
if not recv_bytes:
continue
# 解析回包 (50,64)
filt_data = np.frombuffer(recv_bytes, dtype=np.float64)
expect_size = PKG_RECV_SHAPE[0] * PKG_RECV_SHAPE[1]
if filt_data.size != expect_size:
logger.warning(f"回包长度异常:实际{filt_data.size},预期{expect_size}")
continue
filt_data = filt_data.reshape(PKG_RECV_SHAPE)
# 全局收包计数
stat["total_recv"] += 1
# 仅预热完成后,计入有效统计收包
if stat["total_send"] > PREHEAT_SEND_PACKS:
stat["valid_recv"] += 1
# 写入绘图缓冲区
with data_lock:
filt_data_buf.extend(filt_data[:, TARGET_CHANNEL])
# ---------- 新增64通道数据比对发包前64通道 <-> 回包64通道 ----------
raw_64ch = stat["latest_raw_64ch"]
if raw_64ch is not None:
raw_mean = np.mean(raw_64ch)
filt_mean = np.mean(filt_data)
raw_amp = np.max(np.abs(raw_64ch))
filt_amp = np.max(np.abs(filt_data))
logger.debug(
f"【通道数据比对】原始64通道均值:{raw_mean:.4f} 幅值:{raw_amp:.4f} | "
f"滤波后均值:{filt_mean:.4f} 幅值:{filt_amp:.4f}"
)
# ========== 2. 精准定时发送数据包 ==========
current_ts = time.perf_counter()
if current_ts >= next_send_ts:
# 生成(5,66)仿真包
pkt_data = generate_eeg_packet(pkt_index)
pkt_index += 1
send_buf = pkt_data.tobytes()
# 标准三帧Multipart发送
sock.send_multipart([CLIENT_ID, EMPTY_FRAME, send_buf])
# ---------- 发包计数逻辑(核心优化:预热区分) ----------
stat["total_send"] += 1
# 预热完成后,计入有效发包
if stat["total_send"] > PREHEAT_SEND_PACKS:
stat["valid_send"] += 1
# 计算理论应收包数
stat["theo_recv"] = stat["valid_send"] // PACK_RATIO
# 缓存当前包前64通道用于后续数据比对
stat["latest_raw_64ch"] = pkt_data[:, :CH_EEG_VALID]
# 绘图缓冲区(单通道波形)
with data_lock:
raw_data_buf.extend(pkt_data[:, TARGET_CHANNEL])
# 更新下一次发包时间
next_send_ts += send_interval
# ========== 3. 定时打印统计信息(区分预热/正式统计) ==========
now = time.perf_counter()
if now - stat["last_print_time"] > PRINT_STAT_INTERVAL:
run_sec = now - stat["start_time"]
total_send = stat["total_send"]
total_recv = stat["total_recv"]
# 分支1仍在预热阶段
if total_send <= PREHEAT_SEND_PACKS:
remain = PREHEAT_SEND_PACKS - total_send
logger.info(
f"[预热中] 运行:{run_sec:.1f}s | 已发包:{total_send}/{PREHEAT_SEND_PACKS} | "
f"剩余预热包:{remain} | 暂不统计丢包"
)
# 分支2预热完成进入正式统计
else:
v_send = stat["valid_send"]
v_recv = stat["valid_recv"]
t_recv = stat["theo_recv"]
loss_cnt = t_recv - v_recv
loss_rate = (loss_cnt / t_recv * 100) if t_recv > 0 else 0.0
logger.info(
f"[正式统计] 运行:{run_sec:.1f}s | "
f"全局总包: 发{total_send}/收{total_recv} | "
f"有效区间: 发{v_send}/应收{t_recv}/实收{v_recv} | "
f"丢包数:{loss_cnt} | 丢包率:{loss_rate:.2f}%"
)
stat["last_print_time"] = now
except zmq.ZMQError as e:
if e.errno == zmq.EAGAIN:
continue
logger.warning(f"ZMQ 连接异常: {e}")
sock.close()
poller.unregister(sock)
if not ENABLE_RECONNECT:
break
logger.info("500ms 后尝试重连...")
time.sleep(0.5)
except Exception as e:
logger.error(f"IO线程未知异常:\n{traceback.format_exc()}")
break
context.term()
logger.info("ZMQ IO 线程已退出")
# ===================== 【6. 可视化绘图(无改动)】 =====================
def init_plot():
fig = plt.figure(figsize=(14, 9))
fig.suptitle(f"脑电滤波测试 | 观测通道: {TARGET_CHANNEL}", fontsize=14)
ax1 = plt.subplot(2, 2, 1)
ax1.set_title("原始输入波形 (含噪声+工频)")
ax1.set_ylabel("幅值")
ax1.grid(True, alpha=0.3)
line_raw, = ax1.plot([], [], color="#1f77b4", linewidth=1)
ax2 = plt.subplot(2, 2, 2)
ax2.set_title("滤波后输出波形")
ax2.set_ylabel("幅值")
ax2.grid(True, alpha=0.3)
line_filt, = ax2.plot([], [], color="#d62728", linewidth=1)
ax3 = plt.subplot(2, 2, 3)
ax3.set_title("原始信号频谱")
ax3.set_xlabel("频率 (Hz)")
ax3.set_xlim(*PLOT_X_LIMIT_FREQ)
ax3.grid(True, alpha=0.3)
line_raw_fft, = ax3.plot([], [], color="#1f77b4")
ax4 = plt.subplot(2, 2, 4)
ax4.set_title("滤波后信号频谱")
ax4.set_xlabel("频率 (Hz)")
ax4.set_xlim(*PLOT_X_LIMIT_FREQ)
ax4.grid(True, alpha=0.3)
line_filt_fft, = ax4.plot([], [], color="#d62728")
plt.tight_layout(rect=[0, 0, 1, 0.96])
return fig, [line_raw, line_filt, line_raw_fft, line_filt_fft], [ax1, ax2, ax3, ax4]
def update_plot(frame, lines, axes):
line_raw, line_filt, line_raw_fft, line_filt_fft = lines
ax1, ax2, ax3, ax4 = axes
with data_lock:
raw_data = list(raw_data_buf)
filt_data = list(filt_data_buf)
if raw_data:
x_raw = np.arange(len(raw_data))
line_raw.set_data(x_raw, raw_data)
ax1.relim()
ax1.autoscale_view()
if filt_data:
x_filt = np.arange(len(filt_data))
line_filt.set_data(x_filt, filt_data)
ax2.relim()
ax2.autoscale_view()
def calc_fft(sig, n_fft):
if len(sig) < n_fft:
return [], []
win = np.hanning(n_fft)
sig_win = sig[-n_fft:] * win
fft_vals = np.fft.fft(sig_win)
fft_amp = np.abs(fft_vals)[:n_fft//2]
freq = np.fft.fftfreq(n_fft, 1/SAMPLE_RATE)[:n_fft//2]
return freq, fft_amp
freq_raw, amp_raw = calc_fft(raw_data, FFT_N_POINTS)
freq_filt, amp_filt = calc_fft(filt_data, FFT_N_POINTS)
line_raw_fft.set_data(freq_raw, amp_raw)
line_filt_fft.set_data(freq_filt, amp_filt)
ax3.relim()
ax3.autoscale_view(scaley=True)
ax4.relim()
ax4.autoscale_view(scaley=True)
return lines
# ===================== 【7. 资源释放 & 最终汇总统计】 =====================
def clean_resource():
g_running.clear()
logger.info("开始停止所有线程...")
time.sleep(0.3)
plt.close("all")
logger.info("资源释放完成")
def main():
logger.info("=" * 70)
logger.info("脑电滤波测试客户端【统计逻辑优化版】启动")
logger.info(f"服务端地址: {ZMQ_SERVER_IP}:{ZMQ_SERVER_PORT}")
logger.info(f"发包: {PKG_SEND_SHAPE}({SEND_INTERVAL*1000:.0f}ms) | 回包: {PKG_RECV_SHAPE}({RECV_INTERVAL*1000:.0f}ms)")
logger.info(f"预热规则: {PREHEAT_SECONDS}秒 / {PREHEAT_SEND_PACKS} 包后开启统计")
logger.info(f"收发比例: 每 {PACK_RATIO} 个发包对应 1 个回包")
logger.info("=" * 70)
# 启动ZMQ收发线程
io_thread = threading.Thread(target=zmq_io_thread, daemon=True, name="ZMQ_IO_Thread")
io_thread.start()
# 启动可视化
fig, lines, axes = init_plot()
ani = FuncAnimation(
fig, update_plot,
fargs=(lines, axes),
interval=PLOT_REFRESH_INTERVAL,
blit=True,
cache_frame_data=False
)
try:
plt.show()
except KeyboardInterrupt:
logger.info("收到 Ctrl+C 中断信号,准备退出")
finally:
# 输出最终完整汇总报表
run_total = time.perf_counter() - stat["start_time"]
total_send = stat["total_send"]
total_recv = stat["total_recv"]
v_send = stat["valid_send"]
v_recv = stat["valid_recv"]
t_recv = stat["theo_recv"]
loss_cnt = t_recv - v_recv
loss_rate = (loss_cnt / t_recv * 100) if t_recv > 0 else 0.0
logger.info(f"\n{'='*50} 最终运行汇总 {'='*50}")
logger.info(f"总运行时长: {run_total:.1f} s")
logger.info(f"【全局总包数】发送: {total_send} | 接收: {total_recv}")
logger.info(f"【有效统计区间(跳过预热{PREHEAT_SEND_PACKS}包)】")
logger.info(f" 有效发包: {v_send} | 理论应收包: {t_recv} | 实际收包: {v_recv}")
logger.info(f" 总丢包数: {loss_cnt} | 整体丢包率: {loss_rate:.2f} %")
logger.info(f"{'='*106}")
clean_resource()
sys.exit(0)
if __name__ == "__main__":
main()