del train
This commit is contained in:
39
Decoder.py
39
Decoder.py
@@ -96,7 +96,7 @@ class Decoder_main(threading.Thread):
|
||||
elif decoder_class == 'ssmvep':
|
||||
self.zmqServer.interval_init(decoder_class)
|
||||
self.n_chan = 8
|
||||
self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
|
||||
self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch')) # [0.2, 2.2]
|
||||
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位
|
||||
self.single_train = 10 # 单类别数量
|
||||
self.num_target = 2 # 分类目标数目
|
||||
@@ -268,26 +268,29 @@ class Decoder_main(threading.Thread):
|
||||
|
||||
'''训练阶段采集数据'''
|
||||
if self.zmqServer.state_mode == 'train': # 训练状态
|
||||
if self.zmqServer.StartTrain:
|
||||
|
||||
|
||||
if self.zmqServer.epoch_finished and self.zmqServer.paradigmBuffer.GetDataLenCount() >= \
|
||||
self.train_epoch[1] + self.zmqServer.event_inner_idx:
|
||||
|
||||
self.currentLabel = self.zmqServer.currentLabel
|
||||
self.zmqServer.StartTrain = False
|
||||
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
||||
self.train_epoch[1] \
|
||||
+ self.zmqServer.event_inner_idx:
|
||||
|
||||
print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount())
|
||||
trainTrial = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 取出所有数据
|
||||
|
||||
print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, self.zmqServer.event_inner_idx])
|
||||
trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理
|
||||
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.train_epoch[
|
||||
0]:self.zmqServer.event_inner_idx + self.train_epoch[1]]
|
||||
print('trial: ', self.zmqServer.event_inner_idx, self.train_epoch[0], self.train_epoch[1])
|
||||
if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance(
|
||||
self.trainLabel, list) \
|
||||
and self.trainLabel.count(self.currentLabel) < self.single_train:
|
||||
self.trainData.append(trainTrial)
|
||||
self.trainLabel.append(self.currentLabel)
|
||||
else:
|
||||
time.sleep(0.0001)
|
||||
return
|
||||
print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount())
|
||||
trainTrial = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 取出所有数据
|
||||
print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, self.zmqServer.event_inner_idx])
|
||||
trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理
|
||||
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.train_epoch[
|
||||
0]:self.zmqServer.event_inner_idx + self.train_epoch[1]]
|
||||
print('trial: ', self.zmqServer.event_inner_idx, self.train_epoch[0], self.train_epoch[1])
|
||||
if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance(
|
||||
self.trainLabel, list) \
|
||||
and self.trainLabel.count(self.currentLabel) < self.single_train:
|
||||
self.trainData.append(trainTrial)
|
||||
self.trainLabel.append(self.currentLabel)
|
||||
|
||||
elif self.zmqServer.state_mode == 'predict': # 测试状态
|
||||
if self.load_model == False: # 模型尚未训练完成
|
||||
|
||||
@@ -19,3 +19,4 @@ source activate 3in1Py310
|
||||
python runDecoder.py
|
||||
python datamock.py
|
||||
python ZeroMQClient_mock.py
|
||||
python system_test.py
|
||||
@@ -21,6 +21,10 @@ class zmqServer(threading.Thread):
|
||||
self.device_info = device_info
|
||||
|
||||
self.host = host
|
||||
|
||||
test_host = "10.200.27.140"
|
||||
self.host = test_host
|
||||
|
||||
self.cmd_port = cmd_port # 命令交互端口:收JSON命令 + 返JSON结果
|
||||
self.data_port = data_port # 数据交互端口:收二进制原始脑电 + 返二进制滤波结果
|
||||
self.running = False
|
||||
@@ -105,14 +109,14 @@ class zmqServer(threading.Thread):
|
||||
|
||||
def interval_init(self, decoder_class):
|
||||
if decoder_class == 'ssmvep':
|
||||
interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
|
||||
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch]
|
||||
interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch')) # [0.2, 2.2]
|
||||
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] # [50, 550]
|
||||
self.train_epoch = [
|
||||
int(self.interval_epoch[0]),
|
||||
int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])
|
||||
]
|
||||
self.latency = (self.interval_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5
|
||||
self.train_latency = (self.train_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5
|
||||
] # [50, 575]
|
||||
self.latency = (self.interval_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5 #115包, 575个点
|
||||
self.train_latency = (self.train_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5 #120包 600个点
|
||||
|
||||
elif decoder_class == 'mi':
|
||||
interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
|
||||
@@ -246,8 +250,6 @@ class zmqServer(threading.Thread):
|
||||
self.decoder_switch = True
|
||||
elif method == "train":
|
||||
self.state_mode = 'train'
|
||||
self.StartTrain = True
|
||||
self.currentLabel = params
|
||||
elif method == "predict":
|
||||
self.state_mode = 'predict'
|
||||
if params == 1: #开始解码
|
||||
@@ -322,22 +324,24 @@ class zmqServer(threading.Thread):
|
||||
def detect_event(self, samples):
|
||||
self.pack_contain_event = False
|
||||
# 第65通道为事件通道
|
||||
events = samples[-2].tolist()
|
||||
for idx, event in enumerate(events):
|
||||
if int(event) in self.events:
|
||||
new_key = "".join(
|
||||
[
|
||||
str(event),
|
||||
datetime.datetime.now().strftime("%Y-%m-%d \
|
||||
-%H-%M-%S"),
|
||||
]
|
||||
)
|
||||
if event == self.predict_event:
|
||||
self.count_events[new_key] = self.latency + 1
|
||||
else:
|
||||
self.count_events[new_key] = self.train_latency + 1
|
||||
self.event_inner_idx = idx
|
||||
self.pack_contain_event = True
|
||||
event = int(samples[-2][0])
|
||||
# for idx, event in enumerate(events):
|
||||
if event in self.events:
|
||||
new_key = "".join(
|
||||
[
|
||||
str(event),
|
||||
datetime.datetime.now().strftime("%Y-%m-%d \
|
||||
-%H-%M-%S"),
|
||||
]
|
||||
)
|
||||
self.currentLabel = event
|
||||
if event == self.predict_event:
|
||||
self.count_events[new_key] = self.latency + 1
|
||||
else:
|
||||
self.count_events[new_key] = self.train_latency + 1
|
||||
self.event_inner_idx = self.device_info['frame_points'] - 1
|
||||
# algo_log(f"事件检测到: {event},索引: {idx}", level="DEBUG")
|
||||
self.pack_contain_event = True
|
||||
|
||||
# 倒计时并清理过期事件
|
||||
drop_items = []
|
||||
|
||||
@@ -11,6 +11,7 @@ N_CHAN = 66 # 通道数: 64 EEG + 1 标签值 + 1 标签序号
|
||||
EEG_FREQ = 10 # EEG 正弦波频率 Hz
|
||||
EEG_AMP = 100.0 # EEG 幅值 100μV
|
||||
LABEL_INTERVAL = 5 # 标签间隔秒数
|
||||
# SERVER_ADDR = 'tcp://127.0.0.1:8100'
|
||||
SERVER_ADDR = 'tcp://127.0.0.1:8100'
|
||||
|
||||
# 发送间隔: 每包 5 采样点 / 250Hz = 20ms
|
||||
|
||||
95
logs/log.py
95
logs/log.py
@@ -1,24 +1,54 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
from logging.handlers import RotatingFileHandler
|
||||
import inspect # 新增导入
|
||||
import inspect
|
||||
from PubLibrary.InifileHelper import IniRead
|
||||
|
||||
|
||||
# 全局配置
|
||||
console_output = IniRead('system', 'console_output', '1')
|
||||
log_level = IniRead('system', 'algo_log_level', 'INFO')
|
||||
log_once_cache = set()
|
||||
|
||||
# 缓存已经创建过的logger,避免重复创建handler
|
||||
logger_cache = {}
|
||||
LOG_RETENTION_DAYS = 3
|
||||
LOG_DIR = './logs/'
|
||||
LOG_FILE_PREFIX = 'algo_log_'
|
||||
|
||||
# 日志格式:时间 - 日志器名 - 级别 - 文件名:行号 - 函数名 - 日志内容
|
||||
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
DATE_FORMAT = '%Y-%m-%d %H:%M:%S'
|
||||
|
||||
|
||||
def clean_old_logs():
|
||||
"""清理超过指定天数的旧日志文件"""
|
||||
try:
|
||||
if not os.path.exists(LOG_DIR):
|
||||
return
|
||||
expire_date = datetime.now() - timedelta(days=LOG_RETENTION_DAYS)
|
||||
for filename in os.listdir(LOG_DIR):
|
||||
if not filename.startswith(LOG_FILE_PREFIX) or not filename.endswith('.log'):
|
||||
continue
|
||||
date_str = filename[len(LOG_FILE_PREFIX):-4]
|
||||
try:
|
||||
file_date = datetime.strptime(date_str, '%Y-%m-%d')
|
||||
if file_date < expire_date:
|
||||
file_path = os.path.join(LOG_DIR, filename)
|
||||
os.remove(file_path)
|
||||
print(f"清理过期日志: {file_path}")
|
||||
except ValueError:
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"清理旧日志异常: {str(e)}")
|
||||
|
||||
|
||||
def init_module_logger(logger_name):
|
||||
log_dir = './logs/'
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_file = os.path.join(log_dir, f'algo_log_{datetime.now().strftime("%Y-%m-%d")}.log')
|
||||
"""初始化日志器 + 清理旧日志"""
|
||||
os.makedirs(LOG_DIR, exist_ok=True)
|
||||
clean_old_logs()
|
||||
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
log_file = os.path.join(LOG_DIR, f"{LOG_FILE_PREFIX}{current_date}.log")
|
||||
|
||||
# 已创建直接返回
|
||||
if logger_name in logger_cache:
|
||||
return logger_cache[logger_name]
|
||||
|
||||
@@ -28,19 +58,18 @@ def init_module_logger(logger_name):
|
||||
logger_cache[logger_name] = logger
|
||||
return logger
|
||||
|
||||
# 文件输出处理器
|
||||
file_handler = RotatingFileHandler(
|
||||
log_file,
|
||||
maxBytes=10*1024*1024,
|
||||
maxBytes=10 * 1024 * 1024,
|
||||
backupCount=10,
|
||||
encoding='utf-8'
|
||||
)
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
formatter = logging.Formatter(LOG_FORMAT, datefmt=DATE_FORMAT)
|
||||
file_handler.setFormatter(formatter)
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
# 控制台输出
|
||||
if console_output:
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setFormatter(formatter)
|
||||
@@ -51,29 +80,35 @@ def init_module_logger(logger_name):
|
||||
|
||||
|
||||
def algo_log(content, level="INFO", record_once=False):
|
||||
# 向上回溯1层栈,拿到调用algo_log的代码文件信息
|
||||
frame = inspect.currentframe().f_back
|
||||
file_path = frame.f_code.co_filename
|
||||
# 提取py文件名(不带后缀/带后缀自选)
|
||||
file_name = os.path.basename(file_path) # 例:zmqServer.py
|
||||
# file_name = os.path.splitext(os.path.basename(file_path))[0] # 例:zmqServer
|
||||
"""
|
||||
日志入口函数
|
||||
自动记录:调用文件名、代码行号、所在函数
|
||||
"""
|
||||
# 回溯栈帧,获取真正调用 algo_log 的代码位置
|
||||
# f_back(1) -> algo_log 自身,f_back(2) -> 业务调用处
|
||||
frame = inspect.currentframe().f_back.f_back
|
||||
if not frame:
|
||||
file_name = "unknown"
|
||||
else:
|
||||
file_name = os.path.basename(frame.f_code.co_filename)
|
||||
|
||||
logger = init_module_logger(file_name)
|
||||
|
||||
# 单次日志去重
|
||||
if record_once:
|
||||
log_key = f"{level.upper()}_{content}"
|
||||
if log_key in log_once_cache:
|
||||
return
|
||||
log_once_cache.add(log_key)
|
||||
|
||||
# 日志级别分发
|
||||
level_upper = level.upper()
|
||||
if level_upper == "DEBUG":
|
||||
logger.debug(content)
|
||||
elif level_upper == "WARNING":
|
||||
logger.warning(content)
|
||||
elif level_upper == "ERROR":
|
||||
logger.error(content)
|
||||
elif level_upper == "FATAL":
|
||||
logger.fatal(content)
|
||||
else:
|
||||
logger.info(content)
|
||||
log_map = {
|
||||
"DEBUG": logger.debug,
|
||||
"WARNING": logger.warning,
|
||||
"ERROR": logger.error,
|
||||
"FATAL": logger.fatal,
|
||||
"INFO": logger.info
|
||||
}
|
||||
log_func = log_map.get(level_upper, logger.info)
|
||||
log_func(content)
|
||||
422
system_test.py
Normal file
422
system_test.py
Normal file
@@ -0,0 +1,422 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
ZMQ 脑电数据测试工具【语法错误修复版】
|
||||
修复点:
|
||||
1. dataclass 可变列表默认值报错
|
||||
2. threading.Thread daemon 参数语法错误
|
||||
适配:Python3.10、全链路 float64、ZMQ DEALER<->ROUTER
|
||||
端口:8099(命令) / 8100(数据)
|
||||
"""
|
||||
import zmq
|
||||
import time
|
||||
import threading
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Union, Tuple
|
||||
from matplotlib.animation import FuncAnimation
|
||||
|
||||
# ===================== 1. 配置管理 =====================
|
||||
@dataclass(frozen=True) # 冻结配置类
|
||||
class TestConfig:
|
||||
# 网络配置
|
||||
SERVER_IP: str = "127.0.0.1"
|
||||
CMD_PORT: int = 8099
|
||||
DATA_PORT: int = 8100
|
||||
|
||||
# 硬件与时序
|
||||
SAMPLE_RATE: int = 250
|
||||
FRAME_INTERVAL_MS: int = 20
|
||||
SEND_INTERVAL: float = FRAME_INTERVAL_MS / 1000
|
||||
CHANNEL_NUMS: int = 66
|
||||
FRAME_POINTS: int = 5
|
||||
FILTER_OUT_CHAN: int = 64
|
||||
FILTER_FRAME_POINTS: int = 50
|
||||
|
||||
# 数据类型 & 字节数 (float64 8字节)
|
||||
DATA_DTYPE: np.dtype = np.float64
|
||||
RAW_FRAME_BYTES: int = CHANNEL_NUMS * FRAME_POINTS * 8 # 66*5*8 = 2640
|
||||
FILTER_FRAME_BYTES: int = FILTER_OUT_CHAN * FILTER_FRAME_POINTS * 8 # 25600
|
||||
|
||||
# 事件通道索引
|
||||
EVENT_CHANNEL_IDX: int = -2
|
||||
|
||||
# 列表类型 使用 default_factory 规避可变默认值报错
|
||||
EVENT_TAGS: List[int] = field(default_factory=lambda: [1, 2, 99])
|
||||
SIM_SIGNAL_FREQ: List[float] = field(default_factory=lambda: [8.0, 9.0])
|
||||
|
||||
# 仿真噪声
|
||||
NOISE_STD: float = 0.25
|
||||
|
||||
# 可视化配置
|
||||
PLOT_TARGET_CHAN: int = 0
|
||||
PLOT_WINDOW_LEN: int = 400
|
||||
PLOT_REFRESH_INTERVAL: int = 50
|
||||
|
||||
# 日志限流
|
||||
FRAME_ERR_INTERVAL: float = 3.0
|
||||
|
||||
# ZMQ 配置
|
||||
SEND_RETRY_MAX: int = 3
|
||||
SEND_RETRY_SLEEP: float = 0.01
|
||||
ZMQ_HWM: int = 1000
|
||||
|
||||
# 初始化全局配置
|
||||
CONFIG = TestConfig()
|
||||
|
||||
# ===================== 2. 全局状态管理 =====================
|
||||
class GlobalState:
|
||||
def __init__(self):
|
||||
self.run_flag: bool = True
|
||||
self.last_frame_err_time: float = 0.0
|
||||
|
||||
GLOBAL_STATE = GlobalState()
|
||||
|
||||
# ===================== 3. Matplotlib 中文初始化 =====================
|
||||
def init_matplotlib():
|
||||
# Windows 黑体,Linux/Mac 自行替换字体
|
||||
plt.rcParams['font.sans-serif'] = ['SimHei']
|
||||
plt.rcParams['axes.unicode_minus'] = False # 修复负号乱码
|
||||
|
||||
init_matplotlib()
|
||||
|
||||
# ===================== 4. ZMQ DEALER 客户端 =====================
|
||||
class ZmqDealerClient:
|
||||
"""适配 ROUTER 的 DEALER 客户端,高频流式数据专用"""
|
||||
def __init__(self, server_ip: str, port: int):
|
||||
self.ctx: zmq.Context = zmq.Context()
|
||||
self.socket: zmq.Socket = self.ctx.socket(zmq.DEALER)
|
||||
self._configure_socket()
|
||||
self.socket.connect(f"tcp://{server_ip}:{port}")
|
||||
|
||||
def _configure_socket(self):
|
||||
"""套接字参数配置"""
|
||||
self.socket.setsockopt(zmq.RCVHWM, CONFIG.ZMQ_HWM)
|
||||
self.socket.setsockopt(zmq.SNDHWM, CONFIG.ZMQ_HWM)
|
||||
self.socket.setsockopt(zmq.RCVTIMEO, 0)
|
||||
self.socket.setsockopt(zmq.SNDTIMEO, 0)
|
||||
|
||||
def send_json(self, data: Dict) -> bool:
|
||||
"""发送JSON命令,带重试机制"""
|
||||
try:
|
||||
payload = json.dumps(data, ensure_ascii=False).encode("utf-8")
|
||||
except Exception as e:
|
||||
print(f"[JSON序列化失败] {e}")
|
||||
return False
|
||||
|
||||
for _ in range(CONFIG.SEND_RETRY_MAX):
|
||||
try:
|
||||
self.socket.send_multipart([b"", payload])
|
||||
return True
|
||||
except zmq.Again:
|
||||
time.sleep(CONFIG.SEND_RETRY_SLEEP)
|
||||
except Exception as e:
|
||||
print(f"[JSON发送异常] {e}")
|
||||
time.sleep(CONFIG.SEND_RETRY_SLEEP)
|
||||
print(f"[JSON发送重试失败]")
|
||||
return False
|
||||
|
||||
def send_bytes(self, data: bytes) -> bool:
|
||||
"""发送二进制脑电数据,带重试"""
|
||||
for _ in range(CONFIG.SEND_RETRY_MAX):
|
||||
try:
|
||||
self.socket.send_multipart([b"", data])
|
||||
return True
|
||||
except zmq.Again:
|
||||
time.sleep(CONFIG.SEND_RETRY_SLEEP)
|
||||
except Exception as e:
|
||||
print(f"[二进制发送异常] {e}")
|
||||
time.sleep(CONFIG.SEND_RETRY_SLEEP)
|
||||
print(f"[二进制发送重试失败]")
|
||||
return False
|
||||
|
||||
def recv_json(self) -> Optional[Dict]:
|
||||
"""接收JSON命令响应(标准3帧)"""
|
||||
try:
|
||||
frames = self.socket.recv_multipart()
|
||||
if len(frames) < 3:
|
||||
self._log_frame_err(f"帧数异常: {len(frames)}")
|
||||
return None
|
||||
payload = frames[2].decode("utf-8")
|
||||
return json.loads(payload)
|
||||
except json.JSONDecodeError:
|
||||
self._log_frame_err("JSON解析失败")
|
||||
return None
|
||||
except Exception as e:
|
||||
self._log_frame_err(f"接收异常: {e}")
|
||||
return None
|
||||
|
||||
def recv_bytes(self) -> Optional[bytes]:
|
||||
"""接收滤波数据,兼容3/4帧格式"""
|
||||
try:
|
||||
frames = self.socket.recv_multipart()
|
||||
frame_len = len(frames)
|
||||
if frame_len == 3:
|
||||
payload = frames[2]
|
||||
elif frame_len == 4:
|
||||
payload = frames[3]
|
||||
else:
|
||||
self._log_frame_err(f"帧数异常: {frame_len}")
|
||||
return None
|
||||
|
||||
if len(payload) != CONFIG.FILTER_FRAME_BYTES:
|
||||
self._log_frame_err(f"字节不匹配: 期望{CONFIG.FILTER_FRAME_BYTES}, 实际{len(payload)}")
|
||||
return None
|
||||
return payload
|
||||
except Exception as e:
|
||||
self._log_frame_err(f"数据接收异常: {e}")
|
||||
return None
|
||||
|
||||
def _log_frame_err(self, msg: str):
|
||||
"""日志限流,防止刷屏"""
|
||||
now = time.time()
|
||||
if now - GLOBAL_STATE.last_frame_err_time > CONFIG.FRAME_ERR_INTERVAL:
|
||||
print(f"[帧异常] {msg}")
|
||||
GLOBAL_STATE.last_frame_err_time = now
|
||||
|
||||
def close(self):
|
||||
"""优雅释放ZMQ资源"""
|
||||
try:
|
||||
self.socket.close(linger=0)
|
||||
self.ctx.term()
|
||||
except Exception as e:
|
||||
print(f"[资源释放异常] {e}")
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
# ===================== 5. 仿真脑电数据生成 =====================
|
||||
def generate_raw_eeg_frame(add_event: bool = False) -> np.ndarray:
|
||||
"""生成单帧float64仿真脑电数据"""
|
||||
t = np.linspace(
|
||||
0, CONFIG.FRAME_POINTS / CONFIG.SAMPLE_RATE,
|
||||
CONFIG.FRAME_POINTS, endpoint=False
|
||||
)
|
||||
eeg_frame = np.zeros(
|
||||
(CONFIG.CHANNEL_NUMS, CONFIG.FRAME_POINTS),
|
||||
dtype=CONFIG.DATA_DTYPE
|
||||
)
|
||||
|
||||
# 模拟脑电信号 + 高斯噪声
|
||||
for freq in CONFIG.SIM_SIGNAL_FREQ:
|
||||
eeg_frame[:CONFIG.FILTER_OUT_CHAN] += np.sin(2 * np.pi * freq * t)
|
||||
eeg_frame[:CONFIG.FILTER_OUT_CHAN] += np.random.normal(
|
||||
0, CONFIG.NOISE_STD,
|
||||
size=(CONFIG.FILTER_OUT_CHAN, CONFIG.FRAME_POINTS)
|
||||
)
|
||||
|
||||
# 事件通道处理
|
||||
eeg_frame[CONFIG.EVENT_CHANNEL_IDX] = 0.0
|
||||
if add_event:
|
||||
event_pos = np.random.randint(0, CONFIG.FRAME_POINTS)
|
||||
eeg_frame[CONFIG.EVENT_CHANNEL_IDX, event_pos] = np.random.choice(CONFIG.EVENT_TAGS)
|
||||
|
||||
# 预留通道置0
|
||||
eeg_frame[-1] = 0.0
|
||||
return eeg_frame
|
||||
|
||||
# ===================== 6. 后台工作线程 =====================
|
||||
def start_cmd_response_thread(cmd_client: ZmqDealerClient):
|
||||
"""命令响应接收线程"""
|
||||
print("[线程-命令接收] 已启动")
|
||||
while GLOBAL_STATE.run_flag:
|
||||
msg = cmd_client.recv_json()
|
||||
if msg:
|
||||
print(f"\n【命令响应】{json.dumps(msg, ensure_ascii=False, indent=2)}")
|
||||
time.sleep(0.01)
|
||||
print("[线程-命令接收] 已退出")
|
||||
|
||||
def start_raw_eeg_send_thread(data_client: ZmqDealerClient):
|
||||
"""原始脑电发送线程(20ms/帧)"""
|
||||
print(f"[线程-原始数据发送] 20ms/帧 | 单帧{CONFIG.RAW_FRAME_BYTES}字节 | float64")
|
||||
frame_count = 0
|
||||
while GLOBAL_STATE.run_flag:
|
||||
insert_event = (frame_count % 20 == 0)
|
||||
eeg_frame = generate_raw_eeg_frame(add_event=insert_event)
|
||||
frame_bytes = eeg_frame.tobytes()
|
||||
|
||||
# 字节校验
|
||||
if len(frame_bytes) != CONFIG.RAW_FRAME_BYTES:
|
||||
print(f"[字节警告] 期望{CONFIG.RAW_FRAME_BYTES}, 实际{len(frame_bytes)}")
|
||||
time.sleep(CONFIG.SEND_INTERVAL)
|
||||
frame_count += 1
|
||||
continue
|
||||
|
||||
data_client.send_bytes(frame_bytes)
|
||||
frame_count += 1
|
||||
time.sleep(CONFIG.SEND_INTERVAL)
|
||||
print("[线程-原始数据发送] 已退出")
|
||||
|
||||
def start_filter_data_recv_thread(data_client: ZmqDealerClient, plot_queue: List[np.ndarray]):
|
||||
"""滤波数据接收线程"""
|
||||
print(f"[线程-滤波数据接收] 单包{CONFIG.FILTER_FRAME_BYTES}字节 | float64")
|
||||
while GLOBAL_STATE.run_flag:
|
||||
raw_bytes = data_client.recv_bytes()
|
||||
if not raw_bytes:
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
try:
|
||||
filter_arr = np.frombuffer(raw_bytes, dtype=CONFIG.DATA_DTYPE)
|
||||
filter_arr = filter_arr.reshape(CONFIG.FILTER_FRAME_POINTS, CONFIG.FILTER_OUT_CHAN)
|
||||
plot_queue.append(filter_arr[:, CONFIG.PLOT_TARGET_CHAN])
|
||||
except Exception as e:
|
||||
print(f"[滤波数据解析异常] {e}")
|
||||
continue
|
||||
print("[线程-滤波数据接收] 已退出")
|
||||
|
||||
# ===================== 7. 实时波形可视化 =====================
|
||||
def start_wave_visualization(plot_queue: List[np.ndarray]):
|
||||
"""启动实时滤波波形绘图"""
|
||||
fig, ax = plt.subplots(figsize=(14, 4))
|
||||
x_axis = np.arange(0, CONFIG.PLOT_WINDOW_LEN)
|
||||
wave_data = np.zeros(CONFIG.PLOT_WINDOW_LEN, dtype=CONFIG.DATA_DTYPE)
|
||||
line, = ax.plot(x_axis, wave_data, color="#2E86AB", linewidth=1.2)
|
||||
|
||||
ax.set_title(
|
||||
f"实时滤波脑电波形 | 通道 {CONFIG.PLOT_TARGET_CHAN} | {CONFIG.SAMPLE_RATE}Hz | float64",
|
||||
fontsize=12
|
||||
)
|
||||
ax.set_ylim(-3.0, 3.0)
|
||||
ax.grid(True, alpha=0.3, linestyle="--")
|
||||
plt.tight_layout()
|
||||
|
||||
def update_plot(_):
|
||||
nonlocal wave_data
|
||||
if plot_queue:
|
||||
new_wave = plot_queue.pop(0)
|
||||
wave_data = np.roll(wave_data, -len(new_wave))
|
||||
wave_data[-len(new_wave)] = new_wave
|
||||
line.set_ydata(wave_data)
|
||||
return (line,)
|
||||
|
||||
ani = FuncAnimation(
|
||||
fig, update_plot,
|
||||
interval=CONFIG.PLOT_REFRESH_INTERVAL,
|
||||
blit=True,
|
||||
cache_frame_data=False
|
||||
)
|
||||
plt.show()
|
||||
|
||||
# ===================== 8. 全量业务测试用例 =====================
|
||||
def run_full_test_cases(cmd_client: ZmqDealerClient):
|
||||
"""全覆盖 zmqServer 所有命令:sync/targetFreqs/decoderClass/impedance/train/predict/rest"""
|
||||
print("\n" + "="*60)
|
||||
print("开始执行全量命令测试用例")
|
||||
print("="*60)
|
||||
time.sleep(2)
|
||||
|
||||
# 1. 同步命令
|
||||
print("\n[用例 1] 发送 sync 命令")
|
||||
cmd_client.send_json({"method": "sync", "params": {}})
|
||||
time.sleep(1)
|
||||
|
||||
# 2. 设置目标频率
|
||||
print("\n[用例 2] 发送 targetFreqs = [8.0, 9.0]")
|
||||
cmd_client.send_json({"method": "targetFreqs", "params": [8.0, 9.0]})
|
||||
time.sleep(1)
|
||||
|
||||
# 3. 切换解码器
|
||||
print("\n[用例 3] 切换解码器为 ssmvep")
|
||||
cmd_client.send_json({"method": "decoderClass", "params": "ssmvep"})
|
||||
time.sleep(2)
|
||||
print("\n[用例 3-2] 切换解码器为 mi")
|
||||
cmd_client.send_json({"method": "decoderClass", "params": "mi"})
|
||||
time.sleep(2)
|
||||
|
||||
# 4. 阻抗检测开关
|
||||
print("\n[用例 4] 开启阻抗检测 impedance=1")
|
||||
cmd_client.send_json({"method": "impedance", "params": 1})
|
||||
time.sleep(1)
|
||||
print("\n[用例 4-2] 关闭阻抗检测 impedance=2")
|
||||
cmd_client.send_json({"method": "impedance", "params": 2})
|
||||
time.sleep(1)
|
||||
|
||||
# 5. 训练模式
|
||||
print("\n[用例 5] 启动训练 train,标签=1")
|
||||
cmd_client.send_json({"method": "train", "params": 1})
|
||||
time.sleep(3)
|
||||
|
||||
# # 6. 休息模式
|
||||
# print("\n[用例 6] 切换 rest 休息模式")
|
||||
# cmd_client.send_json({"method": "rest", "params": {}})
|
||||
# time.sleep(1)
|
||||
|
||||
# 7. 启动解码
|
||||
print("\n[用例 7] 启动解码 predict=1")
|
||||
cmd_client.send_json({"method": "predict", "params": 1})
|
||||
time.sleep(4)
|
||||
|
||||
# # 8. 非法命令(异常测试)
|
||||
# print("\n[用例 8] 发送非法命令 test_cmd_illegal")
|
||||
# cmd_client.send_json({"method": "test_cmd_illegal", "params": {}})
|
||||
# time.sleep(1)
|
||||
|
||||
# # 9. 停止解码
|
||||
# print("\n[用例 9] 停止解码 predict=2")
|
||||
# cmd_client.send_json({"method": "predict", "params": 2})
|
||||
# time.sleep(2)
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("所有测试用例执行完毕")
|
||||
print("="*60)
|
||||
|
||||
# ===================== 主程序入口(修复线程语法) =====================
|
||||
if __name__ == "__main__":
|
||||
print("="*60)
|
||||
print("ZMQ 脑电仿真测试工具 启动")
|
||||
print(f"命令端口: {CONFIG.CMD_PORT} | 数据端口: {CONFIG.DATA_PORT}")
|
||||
print(f"原始帧{CONFIG.RAW_FRAME_BYTES}字节 | 滤波帧{CONFIG.FILTER_FRAME_BYTES}字节 | float64")
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
with ZmqDealerClient(CONFIG.SERVER_IP, CONFIG.CMD_PORT) as cmd_client, \
|
||||
ZmqDealerClient(CONFIG.SERVER_IP, CONFIG.DATA_PORT) as data_client:
|
||||
|
||||
plot_queue = []
|
||||
|
||||
# ========== 重点修复:线程语法,daemon 移出 args ==========
|
||||
# 命令接收线程
|
||||
t_cmd = threading.Thread(
|
||||
target=start_cmd_response_thread,
|
||||
args=(cmd_client,), # 单元素元组保留逗号
|
||||
daemon=True
|
||||
)
|
||||
# 原始数据发送线程
|
||||
t_eeg = threading.Thread(
|
||||
target=start_raw_eeg_send_thread,
|
||||
args=(data_client,),
|
||||
daemon=True
|
||||
)
|
||||
# 滤波数据接收线程
|
||||
t_filter = threading.Thread(
|
||||
target=start_filter_data_recv_thread,
|
||||
args=(data_client, plot_queue),
|
||||
daemon=True
|
||||
)
|
||||
|
||||
# 启动线程
|
||||
t_cmd.start()
|
||||
t_eeg.start()
|
||||
t_filter.start()
|
||||
|
||||
# 执行测试用例
|
||||
run_full_test_cases(cmd_client)
|
||||
|
||||
# 启动可视化(阻塞主线程)
|
||||
print("\n[提示] 波形窗口已启动,关闭窗口 / Ctrl+C 退出程序")
|
||||
start_wave_visualization(plot_queue)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n[用户中断] 接收到 Ctrl+C,准备退出...")
|
||||
except Exception as e:
|
||||
print(f"\n[程序异常] {e}")
|
||||
finally:
|
||||
# 停止所有后台线程
|
||||
GLOBAL_STATE.run_flag = False
|
||||
time.sleep(0.2)
|
||||
print("程序已安全退出")
|
||||
Reference in New Issue
Block a user