del train
This commit is contained in:
39
Decoder.py
39
Decoder.py
@@ -96,7 +96,7 @@ class Decoder_main(threading.Thread):
|
|||||||
elif decoder_class == 'ssmvep':
|
elif decoder_class == 'ssmvep':
|
||||||
self.zmqServer.interval_init(decoder_class)
|
self.zmqServer.interval_init(decoder_class)
|
||||||
self.n_chan = 8
|
self.n_chan = 8
|
||||||
self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
|
self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch')) # [0.2, 2.2]
|
||||||
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位
|
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位
|
||||||
self.single_train = 10 # 单类别数量
|
self.single_train = 10 # 单类别数量
|
||||||
self.num_target = 2 # 分类目标数目
|
self.num_target = 2 # 分类目标数目
|
||||||
@@ -268,26 +268,29 @@ class Decoder_main(threading.Thread):
|
|||||||
|
|
||||||
'''训练阶段采集数据'''
|
'''训练阶段采集数据'''
|
||||||
if self.zmqServer.state_mode == 'train': # 训练状态
|
if self.zmqServer.state_mode == 'train': # 训练状态
|
||||||
if self.zmqServer.StartTrain:
|
|
||||||
|
|
||||||
|
if self.zmqServer.epoch_finished and self.zmqServer.paradigmBuffer.GetDataLenCount() >= \
|
||||||
|
self.train_epoch[1] + self.zmqServer.event_inner_idx:
|
||||||
|
|
||||||
self.currentLabel = self.zmqServer.currentLabel
|
self.currentLabel = self.zmqServer.currentLabel
|
||||||
self.zmqServer.StartTrain = False
|
|
||||||
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount())
|
||||||
self.train_epoch[1] \
|
trainTrial = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 取出所有数据
|
||||||
+ self.zmqServer.event_inner_idx:
|
|
||||||
|
print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, self.zmqServer.event_inner_idx])
|
||||||
|
trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理
|
||||||
|
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.train_epoch[
|
||||||
|
0]:self.zmqServer.event_inner_idx + self.train_epoch[1]]
|
||||||
|
print('trial: ', self.zmqServer.event_inner_idx, self.train_epoch[0], self.train_epoch[1])
|
||||||
|
if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance(
|
||||||
|
self.trainLabel, list) \
|
||||||
|
and self.trainLabel.count(self.currentLabel) < self.single_train:
|
||||||
|
self.trainData.append(trainTrial)
|
||||||
|
self.trainLabel.append(self.currentLabel)
|
||||||
|
else:
|
||||||
time.sleep(0.0001)
|
time.sleep(0.0001)
|
||||||
return
|
return
|
||||||
print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount())
|
|
||||||
trainTrial = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 取出所有数据
|
|
||||||
print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, self.zmqServer.event_inner_idx])
|
|
||||||
trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理
|
|
||||||
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.train_epoch[
|
|
||||||
0]:self.zmqServer.event_inner_idx + self.train_epoch[1]]
|
|
||||||
print('trial: ', self.zmqServer.event_inner_idx, self.train_epoch[0], self.train_epoch[1])
|
|
||||||
if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance(
|
|
||||||
self.trainLabel, list) \
|
|
||||||
and self.trainLabel.count(self.currentLabel) < self.single_train:
|
|
||||||
self.trainData.append(trainTrial)
|
|
||||||
self.trainLabel.append(self.currentLabel)
|
|
||||||
|
|
||||||
elif self.zmqServer.state_mode == 'predict': # 测试状态
|
elif self.zmqServer.state_mode == 'predict': # 测试状态
|
||||||
if self.load_model == False: # 模型尚未训练完成
|
if self.load_model == False: # 模型尚未训练完成
|
||||||
|
|||||||
@@ -19,3 +19,4 @@ source activate 3in1Py310
|
|||||||
python runDecoder.py
|
python runDecoder.py
|
||||||
python datamock.py
|
python datamock.py
|
||||||
python ZeroMQClient_mock.py
|
python ZeroMQClient_mock.py
|
||||||
|
python system_test.py
|
||||||
@@ -21,6 +21,10 @@ class zmqServer(threading.Thread):
|
|||||||
self.device_info = device_info
|
self.device_info = device_info
|
||||||
|
|
||||||
self.host = host
|
self.host = host
|
||||||
|
|
||||||
|
test_host = "10.200.27.140"
|
||||||
|
self.host = test_host
|
||||||
|
|
||||||
self.cmd_port = cmd_port # 命令交互端口:收JSON命令 + 返JSON结果
|
self.cmd_port = cmd_port # 命令交互端口:收JSON命令 + 返JSON结果
|
||||||
self.data_port = data_port # 数据交互端口:收二进制原始脑电 + 返二进制滤波结果
|
self.data_port = data_port # 数据交互端口:收二进制原始脑电 + 返二进制滤波结果
|
||||||
self.running = False
|
self.running = False
|
||||||
@@ -105,14 +109,14 @@ class zmqServer(threading.Thread):
|
|||||||
|
|
||||||
def interval_init(self, decoder_class):
|
def interval_init(self, decoder_class):
|
||||||
if decoder_class == 'ssmvep':
|
if decoder_class == 'ssmvep':
|
||||||
interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
|
interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch')) # [0.2, 2.2]
|
||||||
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch]
|
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] # [50, 550]
|
||||||
self.train_epoch = [
|
self.train_epoch = [
|
||||||
int(self.interval_epoch[0]),
|
int(self.interval_epoch[0]),
|
||||||
int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])
|
int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])
|
||||||
]
|
] # [50, 575]
|
||||||
self.latency = (self.interval_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5
|
self.latency = (self.interval_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5 #115包, 575个点
|
||||||
self.train_latency = (self.train_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5
|
self.train_latency = (self.train_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5 #120包 600个点
|
||||||
|
|
||||||
elif decoder_class == 'mi':
|
elif decoder_class == 'mi':
|
||||||
interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
|
interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
|
||||||
@@ -246,8 +250,6 @@ class zmqServer(threading.Thread):
|
|||||||
self.decoder_switch = True
|
self.decoder_switch = True
|
||||||
elif method == "train":
|
elif method == "train":
|
||||||
self.state_mode = 'train'
|
self.state_mode = 'train'
|
||||||
self.StartTrain = True
|
|
||||||
self.currentLabel = params
|
|
||||||
elif method == "predict":
|
elif method == "predict":
|
||||||
self.state_mode = 'predict'
|
self.state_mode = 'predict'
|
||||||
if params == 1: #开始解码
|
if params == 1: #开始解码
|
||||||
@@ -322,22 +324,24 @@ class zmqServer(threading.Thread):
|
|||||||
def detect_event(self, samples):
|
def detect_event(self, samples):
|
||||||
self.pack_contain_event = False
|
self.pack_contain_event = False
|
||||||
# 第65通道为事件通道
|
# 第65通道为事件通道
|
||||||
events = samples[-2].tolist()
|
event = int(samples[-2][0])
|
||||||
for idx, event in enumerate(events):
|
# for idx, event in enumerate(events):
|
||||||
if int(event) in self.events:
|
if event in self.events:
|
||||||
new_key = "".join(
|
new_key = "".join(
|
||||||
[
|
[
|
||||||
str(event),
|
str(event),
|
||||||
datetime.datetime.now().strftime("%Y-%m-%d \
|
datetime.datetime.now().strftime("%Y-%m-%d \
|
||||||
-%H-%M-%S"),
|
-%H-%M-%S"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
if event == self.predict_event:
|
self.currentLabel = event
|
||||||
self.count_events[new_key] = self.latency + 1
|
if event == self.predict_event:
|
||||||
else:
|
self.count_events[new_key] = self.latency + 1
|
||||||
self.count_events[new_key] = self.train_latency + 1
|
else:
|
||||||
self.event_inner_idx = idx
|
self.count_events[new_key] = self.train_latency + 1
|
||||||
self.pack_contain_event = True
|
self.event_inner_idx = self.device_info['frame_points'] - 1
|
||||||
|
# algo_log(f"事件检测到: {event},索引: {idx}", level="DEBUG")
|
||||||
|
self.pack_contain_event = True
|
||||||
|
|
||||||
# 倒计时并清理过期事件
|
# 倒计时并清理过期事件
|
||||||
drop_items = []
|
drop_items = []
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ N_CHAN = 66 # 通道数: 64 EEG + 1 标签值 + 1 标签序号
|
|||||||
EEG_FREQ = 10 # EEG 正弦波频率 Hz
|
EEG_FREQ = 10 # EEG 正弦波频率 Hz
|
||||||
EEG_AMP = 100.0 # EEG 幅值 100μV
|
EEG_AMP = 100.0 # EEG 幅值 100μV
|
||||||
LABEL_INTERVAL = 5 # 标签间隔秒数
|
LABEL_INTERVAL = 5 # 标签间隔秒数
|
||||||
|
# SERVER_ADDR = 'tcp://127.0.0.1:8100'
|
||||||
SERVER_ADDR = 'tcp://127.0.0.1:8100'
|
SERVER_ADDR = 'tcp://127.0.0.1:8100'
|
||||||
|
|
||||||
# 发送间隔: 每包 5 采样点 / 250Hz = 20ms
|
# 发送间隔: 每包 5 采样点 / 250Hz = 20ms
|
||||||
|
|||||||
95
logs/log.py
95
logs/log.py
@@ -1,24 +1,54 @@
|
|||||||
import os
|
import os
|
||||||
from datetime import datetime
|
from datetime import datetime, timedelta
|
||||||
import logging
|
import logging
|
||||||
from logging.handlers import RotatingFileHandler
|
from logging.handlers import RotatingFileHandler
|
||||||
import inspect # 新增导入
|
import inspect
|
||||||
from PubLibrary.InifileHelper import IniRead
|
from PubLibrary.InifileHelper import IniRead
|
||||||
|
|
||||||
|
# 全局配置
|
||||||
console_output = IniRead('system', 'console_output', '1')
|
console_output = IniRead('system', 'console_output', '1')
|
||||||
log_level = IniRead('system', 'algo_log_level', 'INFO')
|
log_level = IniRead('system', 'algo_log_level', 'INFO')
|
||||||
log_once_cache = set()
|
log_once_cache = set()
|
||||||
|
|
||||||
# 缓存已经创建过的logger,避免重复创建handler
|
|
||||||
logger_cache = {}
|
logger_cache = {}
|
||||||
|
LOG_RETENTION_DAYS = 3
|
||||||
|
LOG_DIR = './logs/'
|
||||||
|
LOG_FILE_PREFIX = 'algo_log_'
|
||||||
|
|
||||||
|
# 日志格式:时间 - 日志器名 - 级别 - 文件名:行号 - 函数名 - 日志内容
|
||||||
|
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||||
|
DATE_FORMAT = '%Y-%m-%d %H:%M:%S'
|
||||||
|
|
||||||
|
|
||||||
|
def clean_old_logs():
|
||||||
|
"""清理超过指定天数的旧日志文件"""
|
||||||
|
try:
|
||||||
|
if not os.path.exists(LOG_DIR):
|
||||||
|
return
|
||||||
|
expire_date = datetime.now() - timedelta(days=LOG_RETENTION_DAYS)
|
||||||
|
for filename in os.listdir(LOG_DIR):
|
||||||
|
if not filename.startswith(LOG_FILE_PREFIX) or not filename.endswith('.log'):
|
||||||
|
continue
|
||||||
|
date_str = filename[len(LOG_FILE_PREFIX):-4]
|
||||||
|
try:
|
||||||
|
file_date = datetime.strptime(date_str, '%Y-%m-%d')
|
||||||
|
if file_date < expire_date:
|
||||||
|
file_path = os.path.join(LOG_DIR, filename)
|
||||||
|
os.remove(file_path)
|
||||||
|
print(f"清理过期日志: {file_path}")
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
print(f"清理旧日志异常: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
def init_module_logger(logger_name):
|
def init_module_logger(logger_name):
|
||||||
log_dir = './logs/'
|
"""初始化日志器 + 清理旧日志"""
|
||||||
os.makedirs(log_dir, exist_ok=True)
|
os.makedirs(LOG_DIR, exist_ok=True)
|
||||||
log_file = os.path.join(log_dir, f'algo_log_{datetime.now().strftime("%Y-%m-%d")}.log')
|
clean_old_logs()
|
||||||
|
|
||||||
|
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||||
|
log_file = os.path.join(LOG_DIR, f"{LOG_FILE_PREFIX}{current_date}.log")
|
||||||
|
|
||||||
# 已创建直接返回
|
|
||||||
if logger_name in logger_cache:
|
if logger_name in logger_cache:
|
||||||
return logger_cache[logger_name]
|
return logger_cache[logger_name]
|
||||||
|
|
||||||
@@ -28,19 +58,18 @@ def init_module_logger(logger_name):
|
|||||||
logger_cache[logger_name] = logger
|
logger_cache[logger_name] = logger
|
||||||
return logger
|
return logger
|
||||||
|
|
||||||
|
# 文件输出处理器
|
||||||
file_handler = RotatingFileHandler(
|
file_handler = RotatingFileHandler(
|
||||||
log_file,
|
log_file,
|
||||||
maxBytes=10*1024*1024,
|
maxBytes=10 * 1024 * 1024,
|
||||||
backupCount=10,
|
backupCount=10,
|
||||||
encoding='utf-8'
|
encoding='utf-8'
|
||||||
)
|
)
|
||||||
formatter = logging.Formatter(
|
formatter = logging.Formatter(LOG_FORMAT, datefmt=DATE_FORMAT)
|
||||||
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
|
||||||
datefmt='%Y-%m-%d %H:%M:%S'
|
|
||||||
)
|
|
||||||
file_handler.setFormatter(formatter)
|
file_handler.setFormatter(formatter)
|
||||||
logger.addHandler(file_handler)
|
logger.addHandler(file_handler)
|
||||||
|
|
||||||
|
# 控制台输出
|
||||||
if console_output:
|
if console_output:
|
||||||
console_handler = logging.StreamHandler()
|
console_handler = logging.StreamHandler()
|
||||||
console_handler.setFormatter(formatter)
|
console_handler.setFormatter(formatter)
|
||||||
@@ -51,29 +80,35 @@ def init_module_logger(logger_name):
|
|||||||
|
|
||||||
|
|
||||||
def algo_log(content, level="INFO", record_once=False):
|
def algo_log(content, level="INFO", record_once=False):
|
||||||
# 向上回溯1层栈,拿到调用algo_log的代码文件信息
|
"""
|
||||||
frame = inspect.currentframe().f_back
|
日志入口函数
|
||||||
file_path = frame.f_code.co_filename
|
自动记录:调用文件名、代码行号、所在函数
|
||||||
# 提取py文件名(不带后缀/带后缀自选)
|
"""
|
||||||
file_name = os.path.basename(file_path) # 例:zmqServer.py
|
# 回溯栈帧,获取真正调用 algo_log 的代码位置
|
||||||
# file_name = os.path.splitext(os.path.basename(file_path))[0] # 例:zmqServer
|
# f_back(1) -> algo_log 自身,f_back(2) -> 业务调用处
|
||||||
|
frame = inspect.currentframe().f_back.f_back
|
||||||
|
if not frame:
|
||||||
|
file_name = "unknown"
|
||||||
|
else:
|
||||||
|
file_name = os.path.basename(frame.f_code.co_filename)
|
||||||
|
|
||||||
logger = init_module_logger(file_name)
|
logger = init_module_logger(file_name)
|
||||||
|
|
||||||
|
# 单次日志去重
|
||||||
if record_once:
|
if record_once:
|
||||||
log_key = f"{level.upper()}_{content}"
|
log_key = f"{level.upper()}_{content}"
|
||||||
if log_key in log_once_cache:
|
if log_key in log_once_cache:
|
||||||
return
|
return
|
||||||
log_once_cache.add(log_key)
|
log_once_cache.add(log_key)
|
||||||
|
|
||||||
|
# 日志级别分发
|
||||||
level_upper = level.upper()
|
level_upper = level.upper()
|
||||||
if level_upper == "DEBUG":
|
log_map = {
|
||||||
logger.debug(content)
|
"DEBUG": logger.debug,
|
||||||
elif level_upper == "WARNING":
|
"WARNING": logger.warning,
|
||||||
logger.warning(content)
|
"ERROR": logger.error,
|
||||||
elif level_upper == "ERROR":
|
"FATAL": logger.fatal,
|
||||||
logger.error(content)
|
"INFO": logger.info
|
||||||
elif level_upper == "FATAL":
|
}
|
||||||
logger.fatal(content)
|
log_func = log_map.get(level_upper, logger.info)
|
||||||
else:
|
log_func(content)
|
||||||
logger.info(content)
|
|
||||||
422
system_test.py
Normal file
422
system_test.py
Normal file
@@ -0,0 +1,422 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
ZMQ 脑电数据测试工具【语法错误修复版】
|
||||||
|
修复点:
|
||||||
|
1. dataclass 可变列表默认值报错
|
||||||
|
2. threading.Thread daemon 参数语法错误
|
||||||
|
适配:Python3.10、全链路 float64、ZMQ DEALER<->ROUTER
|
||||||
|
端口:8099(命令) / 8100(数据)
|
||||||
|
"""
|
||||||
|
import zmq
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import json
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Dict, List, Optional, Union, Tuple
|
||||||
|
from matplotlib.animation import FuncAnimation
|
||||||
|
|
||||||
|
# ===================== 1. 配置管理 =====================
|
||||||
|
@dataclass(frozen=True) # 冻结配置类
|
||||||
|
class TestConfig:
|
||||||
|
# 网络配置
|
||||||
|
SERVER_IP: str = "127.0.0.1"
|
||||||
|
CMD_PORT: int = 8099
|
||||||
|
DATA_PORT: int = 8100
|
||||||
|
|
||||||
|
# 硬件与时序
|
||||||
|
SAMPLE_RATE: int = 250
|
||||||
|
FRAME_INTERVAL_MS: int = 20
|
||||||
|
SEND_INTERVAL: float = FRAME_INTERVAL_MS / 1000
|
||||||
|
CHANNEL_NUMS: int = 66
|
||||||
|
FRAME_POINTS: int = 5
|
||||||
|
FILTER_OUT_CHAN: int = 64
|
||||||
|
FILTER_FRAME_POINTS: int = 50
|
||||||
|
|
||||||
|
# 数据类型 & 字节数 (float64 8字节)
|
||||||
|
DATA_DTYPE: np.dtype = np.float64
|
||||||
|
RAW_FRAME_BYTES: int = CHANNEL_NUMS * FRAME_POINTS * 8 # 66*5*8 = 2640
|
||||||
|
FILTER_FRAME_BYTES: int = FILTER_OUT_CHAN * FILTER_FRAME_POINTS * 8 # 25600
|
||||||
|
|
||||||
|
# 事件通道索引
|
||||||
|
EVENT_CHANNEL_IDX: int = -2
|
||||||
|
|
||||||
|
# 列表类型 使用 default_factory 规避可变默认值报错
|
||||||
|
EVENT_TAGS: List[int] = field(default_factory=lambda: [1, 2, 99])
|
||||||
|
SIM_SIGNAL_FREQ: List[float] = field(default_factory=lambda: [8.0, 9.0])
|
||||||
|
|
||||||
|
# 仿真噪声
|
||||||
|
NOISE_STD: float = 0.25
|
||||||
|
|
||||||
|
# 可视化配置
|
||||||
|
PLOT_TARGET_CHAN: int = 0
|
||||||
|
PLOT_WINDOW_LEN: int = 400
|
||||||
|
PLOT_REFRESH_INTERVAL: int = 50
|
||||||
|
|
||||||
|
# 日志限流
|
||||||
|
FRAME_ERR_INTERVAL: float = 3.0
|
||||||
|
|
||||||
|
# ZMQ 配置
|
||||||
|
SEND_RETRY_MAX: int = 3
|
||||||
|
SEND_RETRY_SLEEP: float = 0.01
|
||||||
|
ZMQ_HWM: int = 1000
|
||||||
|
|
||||||
|
# 初始化全局配置
|
||||||
|
CONFIG = TestConfig()
|
||||||
|
|
||||||
|
# ===================== 2. 全局状态管理 =====================
|
||||||
|
class GlobalState:
|
||||||
|
def __init__(self):
|
||||||
|
self.run_flag: bool = True
|
||||||
|
self.last_frame_err_time: float = 0.0
|
||||||
|
|
||||||
|
GLOBAL_STATE = GlobalState()
|
||||||
|
|
||||||
|
# ===================== 3. Matplotlib 中文初始化 =====================
|
||||||
|
def init_matplotlib():
|
||||||
|
# Windows 黑体,Linux/Mac 自行替换字体
|
||||||
|
plt.rcParams['font.sans-serif'] = ['SimHei']
|
||||||
|
plt.rcParams['axes.unicode_minus'] = False # 修复负号乱码
|
||||||
|
|
||||||
|
init_matplotlib()
|
||||||
|
|
||||||
|
# ===================== 4. ZMQ DEALER 客户端 =====================
|
||||||
|
class ZmqDealerClient:
|
||||||
|
"""适配 ROUTER 的 DEALER 客户端,高频流式数据专用"""
|
||||||
|
def __init__(self, server_ip: str, port: int):
|
||||||
|
self.ctx: zmq.Context = zmq.Context()
|
||||||
|
self.socket: zmq.Socket = self.ctx.socket(zmq.DEALER)
|
||||||
|
self._configure_socket()
|
||||||
|
self.socket.connect(f"tcp://{server_ip}:{port}")
|
||||||
|
|
||||||
|
def _configure_socket(self):
|
||||||
|
"""套接字参数配置"""
|
||||||
|
self.socket.setsockopt(zmq.RCVHWM, CONFIG.ZMQ_HWM)
|
||||||
|
self.socket.setsockopt(zmq.SNDHWM, CONFIG.ZMQ_HWM)
|
||||||
|
self.socket.setsockopt(zmq.RCVTIMEO, 0)
|
||||||
|
self.socket.setsockopt(zmq.SNDTIMEO, 0)
|
||||||
|
|
||||||
|
def send_json(self, data: Dict) -> bool:
|
||||||
|
"""发送JSON命令,带重试机制"""
|
||||||
|
try:
|
||||||
|
payload = json.dumps(data, ensure_ascii=False).encode("utf-8")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[JSON序列化失败] {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
for _ in range(CONFIG.SEND_RETRY_MAX):
|
||||||
|
try:
|
||||||
|
self.socket.send_multipart([b"", payload])
|
||||||
|
return True
|
||||||
|
except zmq.Again:
|
||||||
|
time.sleep(CONFIG.SEND_RETRY_SLEEP)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[JSON发送异常] {e}")
|
||||||
|
time.sleep(CONFIG.SEND_RETRY_SLEEP)
|
||||||
|
print(f"[JSON发送重试失败]")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def send_bytes(self, data: bytes) -> bool:
|
||||||
|
"""发送二进制脑电数据,带重试"""
|
||||||
|
for _ in range(CONFIG.SEND_RETRY_MAX):
|
||||||
|
try:
|
||||||
|
self.socket.send_multipart([b"", data])
|
||||||
|
return True
|
||||||
|
except zmq.Again:
|
||||||
|
time.sleep(CONFIG.SEND_RETRY_SLEEP)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[二进制发送异常] {e}")
|
||||||
|
time.sleep(CONFIG.SEND_RETRY_SLEEP)
|
||||||
|
print(f"[二进制发送重试失败]")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def recv_json(self) -> Optional[Dict]:
|
||||||
|
"""接收JSON命令响应(标准3帧)"""
|
||||||
|
try:
|
||||||
|
frames = self.socket.recv_multipart()
|
||||||
|
if len(frames) < 3:
|
||||||
|
self._log_frame_err(f"帧数异常: {len(frames)}")
|
||||||
|
return None
|
||||||
|
payload = frames[2].decode("utf-8")
|
||||||
|
return json.loads(payload)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
self._log_frame_err("JSON解析失败")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
self._log_frame_err(f"接收异常: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def recv_bytes(self) -> Optional[bytes]:
|
||||||
|
"""接收滤波数据,兼容3/4帧格式"""
|
||||||
|
try:
|
||||||
|
frames = self.socket.recv_multipart()
|
||||||
|
frame_len = len(frames)
|
||||||
|
if frame_len == 3:
|
||||||
|
payload = frames[2]
|
||||||
|
elif frame_len == 4:
|
||||||
|
payload = frames[3]
|
||||||
|
else:
|
||||||
|
self._log_frame_err(f"帧数异常: {frame_len}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if len(payload) != CONFIG.FILTER_FRAME_BYTES:
|
||||||
|
self._log_frame_err(f"字节不匹配: 期望{CONFIG.FILTER_FRAME_BYTES}, 实际{len(payload)}")
|
||||||
|
return None
|
||||||
|
return payload
|
||||||
|
except Exception as e:
|
||||||
|
self._log_frame_err(f"数据接收异常: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _log_frame_err(self, msg: str):
|
||||||
|
"""日志限流,防止刷屏"""
|
||||||
|
now = time.time()
|
||||||
|
if now - GLOBAL_STATE.last_frame_err_time > CONFIG.FRAME_ERR_INTERVAL:
|
||||||
|
print(f"[帧异常] {msg}")
|
||||||
|
GLOBAL_STATE.last_frame_err_time = now
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""优雅释放ZMQ资源"""
|
||||||
|
try:
|
||||||
|
self.socket.close(linger=0)
|
||||||
|
self.ctx.term()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[资源释放异常] {e}")
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
self.close()
|
||||||
|
|
||||||
|
# ===================== 5. 仿真脑电数据生成 =====================
|
||||||
|
def generate_raw_eeg_frame(add_event: bool = False) -> np.ndarray:
|
||||||
|
"""生成单帧float64仿真脑电数据"""
|
||||||
|
t = np.linspace(
|
||||||
|
0, CONFIG.FRAME_POINTS / CONFIG.SAMPLE_RATE,
|
||||||
|
CONFIG.FRAME_POINTS, endpoint=False
|
||||||
|
)
|
||||||
|
eeg_frame = np.zeros(
|
||||||
|
(CONFIG.CHANNEL_NUMS, CONFIG.FRAME_POINTS),
|
||||||
|
dtype=CONFIG.DATA_DTYPE
|
||||||
|
)
|
||||||
|
|
||||||
|
# 模拟脑电信号 + 高斯噪声
|
||||||
|
for freq in CONFIG.SIM_SIGNAL_FREQ:
|
||||||
|
eeg_frame[:CONFIG.FILTER_OUT_CHAN] += np.sin(2 * np.pi * freq * t)
|
||||||
|
eeg_frame[:CONFIG.FILTER_OUT_CHAN] += np.random.normal(
|
||||||
|
0, CONFIG.NOISE_STD,
|
||||||
|
size=(CONFIG.FILTER_OUT_CHAN, CONFIG.FRAME_POINTS)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 事件通道处理
|
||||||
|
eeg_frame[CONFIG.EVENT_CHANNEL_IDX] = 0.0
|
||||||
|
if add_event:
|
||||||
|
event_pos = np.random.randint(0, CONFIG.FRAME_POINTS)
|
||||||
|
eeg_frame[CONFIG.EVENT_CHANNEL_IDX, event_pos] = np.random.choice(CONFIG.EVENT_TAGS)
|
||||||
|
|
||||||
|
# 预留通道置0
|
||||||
|
eeg_frame[-1] = 0.0
|
||||||
|
return eeg_frame
|
||||||
|
|
||||||
|
# ===================== 6. 后台工作线程 =====================
|
||||||
|
def start_cmd_response_thread(cmd_client: ZmqDealerClient):
|
||||||
|
"""命令响应接收线程"""
|
||||||
|
print("[线程-命令接收] 已启动")
|
||||||
|
while GLOBAL_STATE.run_flag:
|
||||||
|
msg = cmd_client.recv_json()
|
||||||
|
if msg:
|
||||||
|
print(f"\n【命令响应】{json.dumps(msg, ensure_ascii=False, indent=2)}")
|
||||||
|
time.sleep(0.01)
|
||||||
|
print("[线程-命令接收] 已退出")
|
||||||
|
|
||||||
|
def start_raw_eeg_send_thread(data_client: ZmqDealerClient):
|
||||||
|
"""原始脑电发送线程(20ms/帧)"""
|
||||||
|
print(f"[线程-原始数据发送] 20ms/帧 | 单帧{CONFIG.RAW_FRAME_BYTES}字节 | float64")
|
||||||
|
frame_count = 0
|
||||||
|
while GLOBAL_STATE.run_flag:
|
||||||
|
insert_event = (frame_count % 20 == 0)
|
||||||
|
eeg_frame = generate_raw_eeg_frame(add_event=insert_event)
|
||||||
|
frame_bytes = eeg_frame.tobytes()
|
||||||
|
|
||||||
|
# 字节校验
|
||||||
|
if len(frame_bytes) != CONFIG.RAW_FRAME_BYTES:
|
||||||
|
print(f"[字节警告] 期望{CONFIG.RAW_FRAME_BYTES}, 实际{len(frame_bytes)}")
|
||||||
|
time.sleep(CONFIG.SEND_INTERVAL)
|
||||||
|
frame_count += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
data_client.send_bytes(frame_bytes)
|
||||||
|
frame_count += 1
|
||||||
|
time.sleep(CONFIG.SEND_INTERVAL)
|
||||||
|
print("[线程-原始数据发送] 已退出")
|
||||||
|
|
||||||
|
def start_filter_data_recv_thread(data_client: ZmqDealerClient, plot_queue: List[np.ndarray]):
|
||||||
|
"""滤波数据接收线程"""
|
||||||
|
print(f"[线程-滤波数据接收] 单包{CONFIG.FILTER_FRAME_BYTES}字节 | float64")
|
||||||
|
while GLOBAL_STATE.run_flag:
|
||||||
|
raw_bytes = data_client.recv_bytes()
|
||||||
|
if not raw_bytes:
|
||||||
|
time.sleep(0.01)
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
filter_arr = np.frombuffer(raw_bytes, dtype=CONFIG.DATA_DTYPE)
|
||||||
|
filter_arr = filter_arr.reshape(CONFIG.FILTER_FRAME_POINTS, CONFIG.FILTER_OUT_CHAN)
|
||||||
|
plot_queue.append(filter_arr[:, CONFIG.PLOT_TARGET_CHAN])
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[滤波数据解析异常] {e}")
|
||||||
|
continue
|
||||||
|
print("[线程-滤波数据接收] 已退出")
|
||||||
|
|
||||||
|
# ===================== 7. 实时波形可视化 =====================
|
||||||
|
def start_wave_visualization(plot_queue: List[np.ndarray]):
|
||||||
|
"""启动实时滤波波形绘图"""
|
||||||
|
fig, ax = plt.subplots(figsize=(14, 4))
|
||||||
|
x_axis = np.arange(0, CONFIG.PLOT_WINDOW_LEN)
|
||||||
|
wave_data = np.zeros(CONFIG.PLOT_WINDOW_LEN, dtype=CONFIG.DATA_DTYPE)
|
||||||
|
line, = ax.plot(x_axis, wave_data, color="#2E86AB", linewidth=1.2)
|
||||||
|
|
||||||
|
ax.set_title(
|
||||||
|
f"实时滤波脑电波形 | 通道 {CONFIG.PLOT_TARGET_CHAN} | {CONFIG.SAMPLE_RATE}Hz | float64",
|
||||||
|
fontsize=12
|
||||||
|
)
|
||||||
|
ax.set_ylim(-3.0, 3.0)
|
||||||
|
ax.grid(True, alpha=0.3, linestyle="--")
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
def update_plot(_):
|
||||||
|
nonlocal wave_data
|
||||||
|
if plot_queue:
|
||||||
|
new_wave = plot_queue.pop(0)
|
||||||
|
wave_data = np.roll(wave_data, -len(new_wave))
|
||||||
|
wave_data[-len(new_wave)] = new_wave
|
||||||
|
line.set_ydata(wave_data)
|
||||||
|
return (line,)
|
||||||
|
|
||||||
|
ani = FuncAnimation(
|
||||||
|
fig, update_plot,
|
||||||
|
interval=CONFIG.PLOT_REFRESH_INTERVAL,
|
||||||
|
blit=True,
|
||||||
|
cache_frame_data=False
|
||||||
|
)
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
# ===================== 8. 全量业务测试用例 =====================
|
||||||
|
def run_full_test_cases(cmd_client: ZmqDealerClient):
|
||||||
|
"""全覆盖 zmqServer 所有命令:sync/targetFreqs/decoderClass/impedance/train/predict/rest"""
|
||||||
|
print("\n" + "="*60)
|
||||||
|
print("开始执行全量命令测试用例")
|
||||||
|
print("="*60)
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
# 1. 同步命令
|
||||||
|
print("\n[用例 1] 发送 sync 命令")
|
||||||
|
cmd_client.send_json({"method": "sync", "params": {}})
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
# 2. 设置目标频率
|
||||||
|
print("\n[用例 2] 发送 targetFreqs = [8.0, 9.0]")
|
||||||
|
cmd_client.send_json({"method": "targetFreqs", "params": [8.0, 9.0]})
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
# 3. 切换解码器
|
||||||
|
print("\n[用例 3] 切换解码器为 ssmvep")
|
||||||
|
cmd_client.send_json({"method": "decoderClass", "params": "ssmvep"})
|
||||||
|
time.sleep(2)
|
||||||
|
print("\n[用例 3-2] 切换解码器为 mi")
|
||||||
|
cmd_client.send_json({"method": "decoderClass", "params": "mi"})
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
# 4. 阻抗检测开关
|
||||||
|
print("\n[用例 4] 开启阻抗检测 impedance=1")
|
||||||
|
cmd_client.send_json({"method": "impedance", "params": 1})
|
||||||
|
time.sleep(1)
|
||||||
|
print("\n[用例 4-2] 关闭阻抗检测 impedance=2")
|
||||||
|
cmd_client.send_json({"method": "impedance", "params": 2})
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
# 5. 训练模式
|
||||||
|
print("\n[用例 5] 启动训练 train,标签=1")
|
||||||
|
cmd_client.send_json({"method": "train", "params": 1})
|
||||||
|
time.sleep(3)
|
||||||
|
|
||||||
|
# # 6. 休息模式
|
||||||
|
# print("\n[用例 6] 切换 rest 休息模式")
|
||||||
|
# cmd_client.send_json({"method": "rest", "params": {}})
|
||||||
|
# time.sleep(1)
|
||||||
|
|
||||||
|
# 7. 启动解码
|
||||||
|
print("\n[用例 7] 启动解码 predict=1")
|
||||||
|
cmd_client.send_json({"method": "predict", "params": 1})
|
||||||
|
time.sleep(4)
|
||||||
|
|
||||||
|
# # 8. 非法命令(异常测试)
|
||||||
|
# print("\n[用例 8] 发送非法命令 test_cmd_illegal")
|
||||||
|
# cmd_client.send_json({"method": "test_cmd_illegal", "params": {}})
|
||||||
|
# time.sleep(1)
|
||||||
|
|
||||||
|
# # 9. 停止解码
|
||||||
|
# print("\n[用例 9] 停止解码 predict=2")
|
||||||
|
# cmd_client.send_json({"method": "predict", "params": 2})
|
||||||
|
# time.sleep(2)
|
||||||
|
|
||||||
|
print("\n" + "="*60)
|
||||||
|
print("所有测试用例执行完毕")
|
||||||
|
print("="*60)
|
||||||
|
|
||||||
|
# ===================== 主程序入口(修复线程语法) =====================
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("="*60)
|
||||||
|
print("ZMQ 脑电仿真测试工具 启动")
|
||||||
|
print(f"命令端口: {CONFIG.CMD_PORT} | 数据端口: {CONFIG.DATA_PORT}")
|
||||||
|
print(f"原始帧{CONFIG.RAW_FRAME_BYTES}字节 | 滤波帧{CONFIG.FILTER_FRAME_BYTES}字节 | float64")
|
||||||
|
print("="*60)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with ZmqDealerClient(CONFIG.SERVER_IP, CONFIG.CMD_PORT) as cmd_client, \
|
||||||
|
ZmqDealerClient(CONFIG.SERVER_IP, CONFIG.DATA_PORT) as data_client:
|
||||||
|
|
||||||
|
plot_queue = []
|
||||||
|
|
||||||
|
# ========== 重点修复:线程语法,daemon 移出 args ==========
|
||||||
|
# 命令接收线程
|
||||||
|
t_cmd = threading.Thread(
|
||||||
|
target=start_cmd_response_thread,
|
||||||
|
args=(cmd_client,), # 单元素元组保留逗号
|
||||||
|
daemon=True
|
||||||
|
)
|
||||||
|
# 原始数据发送线程
|
||||||
|
t_eeg = threading.Thread(
|
||||||
|
target=start_raw_eeg_send_thread,
|
||||||
|
args=(data_client,),
|
||||||
|
daemon=True
|
||||||
|
)
|
||||||
|
# 滤波数据接收线程
|
||||||
|
t_filter = threading.Thread(
|
||||||
|
target=start_filter_data_recv_thread,
|
||||||
|
args=(data_client, plot_queue),
|
||||||
|
daemon=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 启动线程
|
||||||
|
t_cmd.start()
|
||||||
|
t_eeg.start()
|
||||||
|
t_filter.start()
|
||||||
|
|
||||||
|
# 执行测试用例
|
||||||
|
run_full_test_cases(cmd_client)
|
||||||
|
|
||||||
|
# 启动可视化(阻塞主线程)
|
||||||
|
print("\n[提示] 波形窗口已启动,关闭窗口 / Ctrl+C 退出程序")
|
||||||
|
start_wave_visualization(plot_queue)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n\n[用户中断] 接收到 Ctrl+C,准备退出...")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n[程序异常] {e}")
|
||||||
|
finally:
|
||||||
|
# 停止所有后台线程
|
||||||
|
GLOBAL_STATE.run_flag = False
|
||||||
|
time.sleep(0.2)
|
||||||
|
print("程序已安全退出")
|
||||||
Reference in New Issue
Block a user