From 07560304ca4b02bbbbef8d77b63051c2cb44c97a Mon Sep 17 00:00:00 2001 From: lizhao Date: Tue, 9 Jun 2026 10:57:28 +0800 Subject: [PATCH] del train --- Decoder.py | 39 +++-- README.md | 1 + Zmq/zmqServer.py | 50 +++--- datamock.py | 1 + logs/log.py | 95 +++++++---- system_test.py | 422 +++++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 537 insertions(+), 71 deletions(-) create mode 100644 system_test.py diff --git a/Decoder.py b/Decoder.py index 2b8b543..c84977c 100644 --- a/Decoder.py +++ b/Decoder.py @@ -96,7 +96,7 @@ class Decoder_main(threading.Thread): elif decoder_class == 'ssmvep': self.zmqServer.interval_init(decoder_class) self.n_chan = 8 - self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch')) + self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch')) # [0.2, 2.2] self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位 self.single_train = 10 # 单类别数量 self.num_target = 2 # 分类目标数目 @@ -268,26 +268,29 @@ class Decoder_main(threading.Thread): '''训练阶段采集数据''' if self.zmqServer.state_mode == 'train': # 训练状态 - if self.zmqServer.StartTrain: + + + if self.zmqServer.epoch_finished and self.zmqServer.paradigmBuffer.GetDataLenCount() >= \ + self.train_epoch[1] + self.zmqServer.event_inner_idx: + self.currentLabel = self.zmqServer.currentLabel - self.zmqServer.StartTrain = False - if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \ - self.train_epoch[1] \ - + self.zmqServer.event_inner_idx: + + print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount()) + trainTrial = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 取出所有数据 + + print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, self.zmqServer.event_inner_idx]) + trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理 + trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.train_epoch[ + 0]:self.zmqServer.event_inner_idx + self.train_epoch[1]] + print('trial: ', self.zmqServer.event_inner_idx, self.train_epoch[0], self.train_epoch[1]) + if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance( + self.trainLabel, list) \ + and self.trainLabel.count(self.currentLabel) < self.single_train: + self.trainData.append(trainTrial) + self.trainLabel.append(self.currentLabel) + else: time.sleep(0.0001) return - print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount()) - trainTrial = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 取出所有数据 - print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, self.zmqServer.event_inner_idx]) - trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理 - trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.train_epoch[ - 0]:self.zmqServer.event_inner_idx + self.train_epoch[1]] - print('trial: ', self.zmqServer.event_inner_idx, self.train_epoch[0], self.train_epoch[1]) - if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance( - self.trainLabel, list) \ - and self.trainLabel.count(self.currentLabel) < self.single_train: - self.trainData.append(trainTrial) - self.trainLabel.append(self.currentLabel) elif self.zmqServer.state_mode == 'predict': # 测试状态 if self.load_model == False: # 模型尚未训练完成 diff --git a/README.md b/README.md index f233ccc..ef177f0 100644 --- a/README.md +++ b/README.md @@ -19,3 +19,4 @@ source activate 3in1Py310 python runDecoder.py python datamock.py python ZeroMQClient_mock.py +python system_test.py \ No newline at end of file diff --git a/Zmq/zmqServer.py b/Zmq/zmqServer.py index 940ccee..19453ef 100644 --- a/Zmq/zmqServer.py +++ b/Zmq/zmqServer.py @@ -21,6 +21,10 @@ class zmqServer(threading.Thread): self.device_info = device_info self.host = host + + test_host = "10.200.27.140" + self.host = test_host + self.cmd_port = cmd_port # 命令交互端口:收JSON命令 + 返JSON结果 self.data_port = data_port # 数据交互端口:收二进制原始脑电 + 返二进制滤波结果 self.running = False @@ -105,14 +109,14 @@ class zmqServer(threading.Thread): def interval_init(self, decoder_class): if decoder_class == 'ssmvep': - interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch')) - self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] + interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch')) # [0.2, 2.2] + self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] # [50, 550] self.train_epoch = [ int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate']) - ] - self.latency = (self.interval_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5 - self.train_latency = (self.train_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5 + ] # [50, 575] + self.latency = (self.interval_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5 #115包, 575个点 + self.train_latency = (self.train_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5 #120包 600个点 elif decoder_class == 'mi': interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch')) @@ -246,8 +250,6 @@ class zmqServer(threading.Thread): self.decoder_switch = True elif method == "train": self.state_mode = 'train' - self.StartTrain = True - self.currentLabel = params elif method == "predict": self.state_mode = 'predict' if params == 1: #开始解码 @@ -322,22 +324,24 @@ class zmqServer(threading.Thread): def detect_event(self, samples): self.pack_contain_event = False # 第65通道为事件通道 - events = samples[-2].tolist() - for idx, event in enumerate(events): - if int(event) in self.events: - new_key = "".join( - [ - str(event), - datetime.datetime.now().strftime("%Y-%m-%d \ - -%H-%M-%S"), - ] - ) - if event == self.predict_event: - self.count_events[new_key] = self.latency + 1 - else: - self.count_events[new_key] = self.train_latency + 1 - self.event_inner_idx = idx - self.pack_contain_event = True + event = int(samples[-2][0]) + # for idx, event in enumerate(events): + if event in self.events: + new_key = "".join( + [ + str(event), + datetime.datetime.now().strftime("%Y-%m-%d \ + -%H-%M-%S"), + ] + ) + self.currentLabel = event + if event == self.predict_event: + self.count_events[new_key] = self.latency + 1 + else: + self.count_events[new_key] = self.train_latency + 1 + self.event_inner_idx = self.device_info['frame_points'] - 1 + # algo_log(f"事件检测到: {event},索引: {idx}", level="DEBUG") + self.pack_contain_event = True # 倒计时并清理过期事件 drop_items = [] diff --git a/datamock.py b/datamock.py index 4b1d84d..54e26ec 100644 --- a/datamock.py +++ b/datamock.py @@ -11,6 +11,7 @@ N_CHAN = 66 # 通道数: 64 EEG + 1 标签值 + 1 标签序号 EEG_FREQ = 10 # EEG 正弦波频率 Hz EEG_AMP = 100.0 # EEG 幅值 100μV LABEL_INTERVAL = 5 # 标签间隔秒数 +# SERVER_ADDR = 'tcp://127.0.0.1:8100' SERVER_ADDR = 'tcp://127.0.0.1:8100' # 发送间隔: 每包 5 采样点 / 250Hz = 20ms diff --git a/logs/log.py b/logs/log.py index dad0bad..b3ea1ef 100644 --- a/logs/log.py +++ b/logs/log.py @@ -1,24 +1,54 @@ import os -from datetime import datetime +from datetime import datetime, timedelta import logging from logging.handlers import RotatingFileHandler -import inspect # 新增导入 +import inspect from PubLibrary.InifileHelper import IniRead - +# 全局配置 console_output = IniRead('system', 'console_output', '1') log_level = IniRead('system', 'algo_log_level', 'INFO') log_once_cache = set() - -# 缓存已经创建过的logger,避免重复创建handler logger_cache = {} +LOG_RETENTION_DAYS = 3 +LOG_DIR = './logs/' +LOG_FILE_PREFIX = 'algo_log_' + +# 日志格式:时间 - 日志器名 - 级别 - 文件名:行号 - 函数名 - 日志内容 +LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s' +DATE_FORMAT = '%Y-%m-%d %H:%M:%S' + + +def clean_old_logs(): + """清理超过指定天数的旧日志文件""" + try: + if not os.path.exists(LOG_DIR): + return + expire_date = datetime.now() - timedelta(days=LOG_RETENTION_DAYS) + for filename in os.listdir(LOG_DIR): + if not filename.startswith(LOG_FILE_PREFIX) or not filename.endswith('.log'): + continue + date_str = filename[len(LOG_FILE_PREFIX):-4] + try: + file_date = datetime.strptime(date_str, '%Y-%m-%d') + if file_date < expire_date: + file_path = os.path.join(LOG_DIR, filename) + os.remove(file_path) + print(f"清理过期日志: {file_path}") + except ValueError: + continue + except Exception as e: + print(f"清理旧日志异常: {str(e)}") + def init_module_logger(logger_name): - log_dir = './logs/' - os.makedirs(log_dir, exist_ok=True) - log_file = os.path.join(log_dir, f'algo_log_{datetime.now().strftime("%Y-%m-%d")}.log') + """初始化日志器 + 清理旧日志""" + os.makedirs(LOG_DIR, exist_ok=True) + clean_old_logs() + + current_date = datetime.now().strftime("%Y-%m-%d") + log_file = os.path.join(LOG_DIR, f"{LOG_FILE_PREFIX}{current_date}.log") - # 已创建直接返回 if logger_name in logger_cache: return logger_cache[logger_name] @@ -28,19 +58,18 @@ def init_module_logger(logger_name): logger_cache[logger_name] = logger return logger + # 文件输出处理器 file_handler = RotatingFileHandler( log_file, - maxBytes=10*1024*1024, + maxBytes=10 * 1024 * 1024, backupCount=10, encoding='utf-8' ) - formatter = logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' - ) + formatter = logging.Formatter(LOG_FORMAT, datefmt=DATE_FORMAT) file_handler.setFormatter(formatter) logger.addHandler(file_handler) + # 控制台输出 if console_output: console_handler = logging.StreamHandler() console_handler.setFormatter(formatter) @@ -51,29 +80,35 @@ def init_module_logger(logger_name): def algo_log(content, level="INFO", record_once=False): - # 向上回溯1层栈,拿到调用algo_log的代码文件信息 - frame = inspect.currentframe().f_back - file_path = frame.f_code.co_filename - # 提取py文件名(不带后缀/带后缀自选) - file_name = os.path.basename(file_path) # 例:zmqServer.py - # file_name = os.path.splitext(os.path.basename(file_path))[0] # 例:zmqServer + """ + 日志入口函数 + 自动记录:调用文件名、代码行号、所在函数 + """ + # 回溯栈帧,获取真正调用 algo_log 的代码位置 + # f_back(1) -> algo_log 自身,f_back(2) -> 业务调用处 + frame = inspect.currentframe().f_back.f_back + if not frame: + file_name = "unknown" + else: + file_name = os.path.basename(frame.f_code.co_filename) logger = init_module_logger(file_name) + # 单次日志去重 if record_once: log_key = f"{level.upper()}_{content}" if log_key in log_once_cache: return log_once_cache.add(log_key) + # 日志级别分发 level_upper = level.upper() - if level_upper == "DEBUG": - logger.debug(content) - elif level_upper == "WARNING": - logger.warning(content) - elif level_upper == "ERROR": - logger.error(content) - elif level_upper == "FATAL": - logger.fatal(content) - else: - logger.info(content) \ No newline at end of file + log_map = { + "DEBUG": logger.debug, + "WARNING": logger.warning, + "ERROR": logger.error, + "FATAL": logger.fatal, + "INFO": logger.info + } + log_func = log_map.get(level_upper, logger.info) + log_func(content) \ No newline at end of file diff --git a/system_test.py b/system_test.py new file mode 100644 index 0000000..41b96b1 --- /dev/null +++ b/system_test.py @@ -0,0 +1,422 @@ +# -*- coding: utf-8 -*- +""" +ZMQ 脑电数据测试工具【语法错误修复版】 +修复点: +1. dataclass 可变列表默认值报错 +2. threading.Thread daemon 参数语法错误 +适配:Python3.10、全链路 float64、ZMQ DEALER<->ROUTER +端口:8099(命令) / 8100(数据) +""" +import zmq +import time +import threading +import numpy as np +import matplotlib.pyplot as plt +import json +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Union, Tuple +from matplotlib.animation import FuncAnimation + +# ===================== 1. 配置管理 ===================== +@dataclass(frozen=True) # 冻结配置类 +class TestConfig: + # 网络配置 + SERVER_IP: str = "127.0.0.1" + CMD_PORT: int = 8099 + DATA_PORT: int = 8100 + + # 硬件与时序 + SAMPLE_RATE: int = 250 + FRAME_INTERVAL_MS: int = 20 + SEND_INTERVAL: float = FRAME_INTERVAL_MS / 1000 + CHANNEL_NUMS: int = 66 + FRAME_POINTS: int = 5 + FILTER_OUT_CHAN: int = 64 + FILTER_FRAME_POINTS: int = 50 + + # 数据类型 & 字节数 (float64 8字节) + DATA_DTYPE: np.dtype = np.float64 + RAW_FRAME_BYTES: int = CHANNEL_NUMS * FRAME_POINTS * 8 # 66*5*8 = 2640 + FILTER_FRAME_BYTES: int = FILTER_OUT_CHAN * FILTER_FRAME_POINTS * 8 # 25600 + + # 事件通道索引 + EVENT_CHANNEL_IDX: int = -2 + + # 列表类型 使用 default_factory 规避可变默认值报错 + EVENT_TAGS: List[int] = field(default_factory=lambda: [1, 2, 99]) + SIM_SIGNAL_FREQ: List[float] = field(default_factory=lambda: [8.0, 9.0]) + + # 仿真噪声 + NOISE_STD: float = 0.25 + + # 可视化配置 + PLOT_TARGET_CHAN: int = 0 + PLOT_WINDOW_LEN: int = 400 + PLOT_REFRESH_INTERVAL: int = 50 + + # 日志限流 + FRAME_ERR_INTERVAL: float = 3.0 + + # ZMQ 配置 + SEND_RETRY_MAX: int = 3 + SEND_RETRY_SLEEP: float = 0.01 + ZMQ_HWM: int = 1000 + +# 初始化全局配置 +CONFIG = TestConfig() + +# ===================== 2. 全局状态管理 ===================== +class GlobalState: + def __init__(self): + self.run_flag: bool = True + self.last_frame_err_time: float = 0.0 + +GLOBAL_STATE = GlobalState() + +# ===================== 3. Matplotlib 中文初始化 ===================== +def init_matplotlib(): + # Windows 黑体,Linux/Mac 自行替换字体 + plt.rcParams['font.sans-serif'] = ['SimHei'] + plt.rcParams['axes.unicode_minus'] = False # 修复负号乱码 + +init_matplotlib() + +# ===================== 4. ZMQ DEALER 客户端 ===================== +class ZmqDealerClient: + """适配 ROUTER 的 DEALER 客户端,高频流式数据专用""" + def __init__(self, server_ip: str, port: int): + self.ctx: zmq.Context = zmq.Context() + self.socket: zmq.Socket = self.ctx.socket(zmq.DEALER) + self._configure_socket() + self.socket.connect(f"tcp://{server_ip}:{port}") + + def _configure_socket(self): + """套接字参数配置""" + self.socket.setsockopt(zmq.RCVHWM, CONFIG.ZMQ_HWM) + self.socket.setsockopt(zmq.SNDHWM, CONFIG.ZMQ_HWM) + self.socket.setsockopt(zmq.RCVTIMEO, 0) + self.socket.setsockopt(zmq.SNDTIMEO, 0) + + def send_json(self, data: Dict) -> bool: + """发送JSON命令,带重试机制""" + try: + payload = json.dumps(data, ensure_ascii=False).encode("utf-8") + except Exception as e: + print(f"[JSON序列化失败] {e}") + return False + + for _ in range(CONFIG.SEND_RETRY_MAX): + try: + self.socket.send_multipart([b"", payload]) + return True + except zmq.Again: + time.sleep(CONFIG.SEND_RETRY_SLEEP) + except Exception as e: + print(f"[JSON发送异常] {e}") + time.sleep(CONFIG.SEND_RETRY_SLEEP) + print(f"[JSON发送重试失败]") + return False + + def send_bytes(self, data: bytes) -> bool: + """发送二进制脑电数据,带重试""" + for _ in range(CONFIG.SEND_RETRY_MAX): + try: + self.socket.send_multipart([b"", data]) + return True + except zmq.Again: + time.sleep(CONFIG.SEND_RETRY_SLEEP) + except Exception as e: + print(f"[二进制发送异常] {e}") + time.sleep(CONFIG.SEND_RETRY_SLEEP) + print(f"[二进制发送重试失败]") + return False + + def recv_json(self) -> Optional[Dict]: + """接收JSON命令响应(标准3帧)""" + try: + frames = self.socket.recv_multipart() + if len(frames) < 3: + self._log_frame_err(f"帧数异常: {len(frames)}") + return None + payload = frames[2].decode("utf-8") + return json.loads(payload) + except json.JSONDecodeError: + self._log_frame_err("JSON解析失败") + return None + except Exception as e: + self._log_frame_err(f"接收异常: {e}") + return None + + def recv_bytes(self) -> Optional[bytes]: + """接收滤波数据,兼容3/4帧格式""" + try: + frames = self.socket.recv_multipart() + frame_len = len(frames) + if frame_len == 3: + payload = frames[2] + elif frame_len == 4: + payload = frames[3] + else: + self._log_frame_err(f"帧数异常: {frame_len}") + return None + + if len(payload) != CONFIG.FILTER_FRAME_BYTES: + self._log_frame_err(f"字节不匹配: 期望{CONFIG.FILTER_FRAME_BYTES}, 实际{len(payload)}") + return None + return payload + except Exception as e: + self._log_frame_err(f"数据接收异常: {e}") + return None + + def _log_frame_err(self, msg: str): + """日志限流,防止刷屏""" + now = time.time() + if now - GLOBAL_STATE.last_frame_err_time > CONFIG.FRAME_ERR_INTERVAL: + print(f"[帧异常] {msg}") + GLOBAL_STATE.last_frame_err_time = now + + def close(self): + """优雅释放ZMQ资源""" + try: + self.socket.close(linger=0) + self.ctx.term() + except Exception as e: + print(f"[资源释放异常] {e}") + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + +# ===================== 5. 仿真脑电数据生成 ===================== +def generate_raw_eeg_frame(add_event: bool = False) -> np.ndarray: + """生成单帧float64仿真脑电数据""" + t = np.linspace( + 0, CONFIG.FRAME_POINTS / CONFIG.SAMPLE_RATE, + CONFIG.FRAME_POINTS, endpoint=False + ) + eeg_frame = np.zeros( + (CONFIG.CHANNEL_NUMS, CONFIG.FRAME_POINTS), + dtype=CONFIG.DATA_DTYPE + ) + + # 模拟脑电信号 + 高斯噪声 + for freq in CONFIG.SIM_SIGNAL_FREQ: + eeg_frame[:CONFIG.FILTER_OUT_CHAN] += np.sin(2 * np.pi * freq * t) + eeg_frame[:CONFIG.FILTER_OUT_CHAN] += np.random.normal( + 0, CONFIG.NOISE_STD, + size=(CONFIG.FILTER_OUT_CHAN, CONFIG.FRAME_POINTS) + ) + + # 事件通道处理 + eeg_frame[CONFIG.EVENT_CHANNEL_IDX] = 0.0 + if add_event: + event_pos = np.random.randint(0, CONFIG.FRAME_POINTS) + eeg_frame[CONFIG.EVENT_CHANNEL_IDX, event_pos] = np.random.choice(CONFIG.EVENT_TAGS) + + # 预留通道置0 + eeg_frame[-1] = 0.0 + return eeg_frame + +# ===================== 6. 后台工作线程 ===================== +def start_cmd_response_thread(cmd_client: ZmqDealerClient): + """命令响应接收线程""" + print("[线程-命令接收] 已启动") + while GLOBAL_STATE.run_flag: + msg = cmd_client.recv_json() + if msg: + print(f"\n【命令响应】{json.dumps(msg, ensure_ascii=False, indent=2)}") + time.sleep(0.01) + print("[线程-命令接收] 已退出") + +def start_raw_eeg_send_thread(data_client: ZmqDealerClient): + """原始脑电发送线程(20ms/帧)""" + print(f"[线程-原始数据发送] 20ms/帧 | 单帧{CONFIG.RAW_FRAME_BYTES}字节 | float64") + frame_count = 0 + while GLOBAL_STATE.run_flag: + insert_event = (frame_count % 20 == 0) + eeg_frame = generate_raw_eeg_frame(add_event=insert_event) + frame_bytes = eeg_frame.tobytes() + + # 字节校验 + if len(frame_bytes) != CONFIG.RAW_FRAME_BYTES: + print(f"[字节警告] 期望{CONFIG.RAW_FRAME_BYTES}, 实际{len(frame_bytes)}") + time.sleep(CONFIG.SEND_INTERVAL) + frame_count += 1 + continue + + data_client.send_bytes(frame_bytes) + frame_count += 1 + time.sleep(CONFIG.SEND_INTERVAL) + print("[线程-原始数据发送] 已退出") + +def start_filter_data_recv_thread(data_client: ZmqDealerClient, plot_queue: List[np.ndarray]): + """滤波数据接收线程""" + print(f"[线程-滤波数据接收] 单包{CONFIG.FILTER_FRAME_BYTES}字节 | float64") + while GLOBAL_STATE.run_flag: + raw_bytes = data_client.recv_bytes() + if not raw_bytes: + time.sleep(0.01) + continue + + try: + filter_arr = np.frombuffer(raw_bytes, dtype=CONFIG.DATA_DTYPE) + filter_arr = filter_arr.reshape(CONFIG.FILTER_FRAME_POINTS, CONFIG.FILTER_OUT_CHAN) + plot_queue.append(filter_arr[:, CONFIG.PLOT_TARGET_CHAN]) + except Exception as e: + print(f"[滤波数据解析异常] {e}") + continue + print("[线程-滤波数据接收] 已退出") + +# ===================== 7. 实时波形可视化 ===================== +def start_wave_visualization(plot_queue: List[np.ndarray]): + """启动实时滤波波形绘图""" + fig, ax = plt.subplots(figsize=(14, 4)) + x_axis = np.arange(0, CONFIG.PLOT_WINDOW_LEN) + wave_data = np.zeros(CONFIG.PLOT_WINDOW_LEN, dtype=CONFIG.DATA_DTYPE) + line, = ax.plot(x_axis, wave_data, color="#2E86AB", linewidth=1.2) + + ax.set_title( + f"实时滤波脑电波形 | 通道 {CONFIG.PLOT_TARGET_CHAN} | {CONFIG.SAMPLE_RATE}Hz | float64", + fontsize=12 + ) + ax.set_ylim(-3.0, 3.0) + ax.grid(True, alpha=0.3, linestyle="--") + plt.tight_layout() + + def update_plot(_): + nonlocal wave_data + if plot_queue: + new_wave = plot_queue.pop(0) + wave_data = np.roll(wave_data, -len(new_wave)) + wave_data[-len(new_wave)] = new_wave + line.set_ydata(wave_data) + return (line,) + + ani = FuncAnimation( + fig, update_plot, + interval=CONFIG.PLOT_REFRESH_INTERVAL, + blit=True, + cache_frame_data=False + ) + plt.show() + +# ===================== 8. 全量业务测试用例 ===================== +def run_full_test_cases(cmd_client: ZmqDealerClient): + """全覆盖 zmqServer 所有命令:sync/targetFreqs/decoderClass/impedance/train/predict/rest""" + print("\n" + "="*60) + print("开始执行全量命令测试用例") + print("="*60) + time.sleep(2) + + # 1. 同步命令 + print("\n[用例 1] 发送 sync 命令") + cmd_client.send_json({"method": "sync", "params": {}}) + time.sleep(1) + + # 2. 设置目标频率 + print("\n[用例 2] 发送 targetFreqs = [8.0, 9.0]") + cmd_client.send_json({"method": "targetFreqs", "params": [8.0, 9.0]}) + time.sleep(1) + + # 3. 切换解码器 + print("\n[用例 3] 切换解码器为 ssmvep") + cmd_client.send_json({"method": "decoderClass", "params": "ssmvep"}) + time.sleep(2) + print("\n[用例 3-2] 切换解码器为 mi") + cmd_client.send_json({"method": "decoderClass", "params": "mi"}) + time.sleep(2) + + # 4. 阻抗检测开关 + print("\n[用例 4] 开启阻抗检测 impedance=1") + cmd_client.send_json({"method": "impedance", "params": 1}) + time.sleep(1) + print("\n[用例 4-2] 关闭阻抗检测 impedance=2") + cmd_client.send_json({"method": "impedance", "params": 2}) + time.sleep(1) + + # 5. 训练模式 + print("\n[用例 5] 启动训练 train,标签=1") + cmd_client.send_json({"method": "train", "params": 1}) + time.sleep(3) + + # # 6. 休息模式 + # print("\n[用例 6] 切换 rest 休息模式") + # cmd_client.send_json({"method": "rest", "params": {}}) + # time.sleep(1) + + # 7. 启动解码 + print("\n[用例 7] 启动解码 predict=1") + cmd_client.send_json({"method": "predict", "params": 1}) + time.sleep(4) + + # # 8. 非法命令(异常测试) + # print("\n[用例 8] 发送非法命令 test_cmd_illegal") + # cmd_client.send_json({"method": "test_cmd_illegal", "params": {}}) + # time.sleep(1) + + # # 9. 停止解码 + # print("\n[用例 9] 停止解码 predict=2") + # cmd_client.send_json({"method": "predict", "params": 2}) + # time.sleep(2) + + print("\n" + "="*60) + print("所有测试用例执行完毕") + print("="*60) + +# ===================== 主程序入口(修复线程语法) ===================== +if __name__ == "__main__": + print("="*60) + print("ZMQ 脑电仿真测试工具 启动") + print(f"命令端口: {CONFIG.CMD_PORT} | 数据端口: {CONFIG.DATA_PORT}") + print(f"原始帧{CONFIG.RAW_FRAME_BYTES}字节 | 滤波帧{CONFIG.FILTER_FRAME_BYTES}字节 | float64") + print("="*60) + + try: + with ZmqDealerClient(CONFIG.SERVER_IP, CONFIG.CMD_PORT) as cmd_client, \ + ZmqDealerClient(CONFIG.SERVER_IP, CONFIG.DATA_PORT) as data_client: + + plot_queue = [] + + # ========== 重点修复:线程语法,daemon 移出 args ========== + # 命令接收线程 + t_cmd = threading.Thread( + target=start_cmd_response_thread, + args=(cmd_client,), # 单元素元组保留逗号 + daemon=True + ) + # 原始数据发送线程 + t_eeg = threading.Thread( + target=start_raw_eeg_send_thread, + args=(data_client,), + daemon=True + ) + # 滤波数据接收线程 + t_filter = threading.Thread( + target=start_filter_data_recv_thread, + args=(data_client, plot_queue), + daemon=True + ) + + # 启动线程 + t_cmd.start() + t_eeg.start() + t_filter.start() + + # 执行测试用例 + run_full_test_cases(cmd_client) + + # 启动可视化(阻塞主线程) + print("\n[提示] 波形窗口已启动,关闭窗口 / Ctrl+C 退出程序") + start_wave_visualization(plot_queue) + + except KeyboardInterrupt: + print("\n\n[用户中断] 接收到 Ctrl+C,准备退出...") + except Exception as e: + print(f"\n[程序异常] {e}") + finally: + # 停止所有后台线程 + GLOBAL_STATE.run_flag = False + time.sleep(0.2) + print("程序已安全退出") \ No newline at end of file