Compare commits

...

15 Commits

Author SHA1 Message Date
c27e250fad update log config 2026-06-13 19:47:27 +08:00
66c0b71b89 update ip 2026-06-13 17:35:46 +08:00
5c7b73b7a4 add log 2026-06-13 16:49:29 +08:00
9690971f43 update 2026-06-13 11:54:58 +08:00
5a5f103ef6 update log 2026-06-13 10:06:29 +08:00
b31bb18dfe update 2026-06-12 15:21:47 +08:00
38480a2ca3 remove release 2026-06-12 14:30:11 +08:00
62e7cab5be add sleep if open impedence 2026-06-12 13:56:48 +08:00
Ivey Song
b26ae2ce3c beta psd 独立线程 2026-06-12 11:33:48 +08:00
5488626112 update 2026-06-11 14:29:43 +08:00
d59b0f695f realeas v1 2026-06-11 11:55:35 +08:00
0570d41439 bug fix 2026-06-11 11:06:59 +08:00
4574798d86 release v1 2026-06-11 09:21:57 +08:00
d480107b37 update 2026-06-11 08:11:29 +08:00
2d70fc9956 update log path 2026-06-11 08:04:08 +08:00
12 changed files with 344 additions and 155 deletions

1
.gitignore vendored
View File

@@ -2,6 +2,7 @@
__pycache__/ __pycache__/
# Distribution / packaging # Distribution / packaging
release/
build/ build/
dist/ dist/
dist_nuitka/ dist_nuitka/

View File

@@ -63,7 +63,7 @@ class Decoder_main(threading.Thread):
# 注册滤波结果回调(示例:打印数据形状) # 注册滤波结果回调(示例:打印数据形状)
self.sliding_filter.filter_result_callback = self.zmqServer.send_filtered_data self.sliding_filter.filter_result_callback = self.zmqServer.send_filtered_data
# 注册 beta_psd 广播回调,每秒通过 8099 端口发送给上位机 # 注册 beta_psd 广播回调,每秒通过 8099 端口发送给上位机
self.sliding_filter.beta_broadcast_callback = lambda v: self.zmqServer.broadcast_message('beta_psd', v) self.sliding_filter.set_beta_broadcast_callback(lambda v: self.zmqServer.broadcast_message('beta_psd', v))
def is_valid_signal(self, data, threshold=1e5): # 判断当前信号是否为有效信号 def is_valid_signal(self, data, threshold=1e5): # 判断当前信号是否为有效信号
# data: (chans, samples) # data: (chans, samples)
@@ -157,8 +157,8 @@ class Decoder_main(threading.Thread):
# self.blink_b, self.blink_a = signal.butter(4, [self.l_freq / (self.device_info['sample_rate'] / 2), self.h_freq / (self.device_info['sample_rate'] / 2)], btype='band') # self.blink_b, self.blink_a = signal.butter(4, [self.l_freq / (self.device_info['sample_rate'] / 2), self.h_freq / (self.device_info['sample_rate'] / 2)], btype='band')
def parameter_init(self,bandPass_low,bandPass_high): def parameter_init(self,bandPass_low,bandPass_high):
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in self.interval_epoch] # epoch截取信息 self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in self.interval_epoch] # epoch截取信息 ssmvep [50, 550]
self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])] # 训练样本epoch self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])] # 训练样本epoch ssmevep [50, 575]
self.trainData = [] #训练数据 self.trainData = [] #训练数据
self.trainLabel = [] #训练标签 self.trainLabel = [] #训练标签
self.plotData = [] #报告分析数据 self.plotData = [] #报告分析数据
@@ -206,6 +206,9 @@ class Decoder_main(threading.Thread):
self.zmqServer.state_mode = 'rest' self.zmqServer.state_mode = 'rest'
try: try:
if self.zmqServer.open_Impedance:
time.sleep(0.005)
continue
if self.decoder_class == 'ssvep' or self.decoder_class == 'pvs': if self.decoder_class == 'ssvep' or self.decoder_class == 'pvs':
self.decoder_SSVEP() self.decoder_SSVEP()
elif self.decoder_class == 'ssmvep': elif self.decoder_class == 'ssmvep':
@@ -215,7 +218,7 @@ class Decoder_main(threading.Thread):
else: else:
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25: if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
time.sleep(0.005) time.sleep(0.005)
continue; continue
self.zmqServer.paradigmBuffer.getData(25) self.zmqServer.paradigmBuffer.getData(25)
except Exception as e: except Exception as e:
algo_log(f"Decoder Loop Error: {e}") algo_log(f"Decoder Loop Error: {e}")
@@ -233,7 +236,7 @@ class Decoder_main(threading.Thread):
if self.zmqServer.open_Impedance: # 阻抗检测状态不解码 if self.zmqServer.open_Impedance: # 阻抗检测状态不解码
return return
data = self.zmqServer.paradigmBuffer.getDataViaSSVEP(50) data = self.zmqServer.paradigmBuffer.getDataViaSSVEP(50)
# algo_log(f"SSVEP取出的{data.shape}, data = {data[:20]}", level="DEBUG") # algo_log(f"SSVEP取出的{data.shape}, data = {data[:, :10]}", level="DEBUG")
data = data[:self.n_chan, :] data = data[:self.n_chan, :]
if self.decodingSteps == 1 and hasattr(self,'dw'): # 开始预热 if self.decodingSteps == 1 and hasattr(self,'dw'): # 开始预热
self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时 self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时
@@ -286,6 +289,7 @@ class Decoder_main(threading.Thread):
and self.trainLabel.count(self.currentLabel) < self.single_train: and self.trainLabel.count(self.currentLabel) < self.single_train:
self.trainData.append(trainTrial) self.trainData.append(trainTrial)
self.trainLabel.append(self.currentLabel) self.trainLabel.append(self.currentLabel)
algo_log(f"SSMVEP训练集{np.shape(self.trainData)}", level="DEBUG")
else: else:
time.sleep(0.0001) time.sleep(0.0001)
return return
@@ -425,7 +429,7 @@ class Decoder_main(threading.Thread):
y_pred = torch.max(Cls, 1)[1] y_pred = torch.max(Cls, 1)[1]
self.plotLabel.append(int(y_pred.item())) self.plotLabel.append(int(y_pred.item()))
algo_log(f"MI运动意图识别: {y_pred}") algo_log(f"MI运动意图识别: {y_pred}")
self.zmqServer.broadcast_message('paradigm', int(y_pred.item())) self.zmqServer.broadcast_message('result', int(y_pred.item()))
end = time.time() end = time.time()
algo_log(f'MI发送给界面完成,耗时{end - start:.3f}s。') algo_log(f'MI发送给界面完成,耗时{end - start:.3f}s。')
else: # 休息状态 else: # 休息状态

View File

@@ -318,11 +318,7 @@ class ExP():
train_pred = torch.max(outputs, 1)[1] train_pred = torch.max(outputs, 1)[1]
train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0)) train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))
algo_log('Epoch:', e, algo_log(f"Epoch = {e}, Train loss = {loss.detach().cpu().numpy():.6f}, Test loss = {loss_test.detach().cpu().numpy():.6f}, Train accuracy = {train_acc:.6f}, Test accuracy = {acc:.6f}", level="debug")
' Train loss: %.6f' % loss.detach().cpu().numpy(),
' Test loss: %.6f' % loss_test.detach().cpu().numpy(),
' Train accuracy %.6f' % train_acc,
' Test accuracy is %.6f' % acc, level="debug")
self.log_write.write(str(e) + " " + str(acc) + "\n") self.log_write.write(str(e) + " " + str(acc) + "\n")
num = num + 1 num = num + 1
@@ -335,8 +331,8 @@ class ExP():
torch.save(self.model, model_path) torch.save(self.model, model_path)
averAcc = averAcc / num averAcc = averAcc / num
algo_log('The average accuracy is:', averAcc, level="debug") algo_log(f"The average accuracy is: {averAcc}", level="debug")
algo_log('The best accuracy is:', bestAcc, level="debug") algo_log(f"The best accuracy is: {bestAcc}", level="debug")
self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n") self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n")
self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n") self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n")
@@ -366,12 +362,13 @@ def onlineTrain(data_queue,result_queue):
data = data_queue.get(timeout=30) data = data_queue.get(timeout=30)
all_data, all_label,model_path,n_chan = data['data'], data['label'],data['modelPath'],data['n_chan'] all_data, all_label,model_path,n_chan = data['data'], data['label'],data['modelPath'],data['n_chan']
exp = ExP(n_chan) exp = ExP(n_chan)
algo_log('训练参数: ',np.shape(all_data),np.shape(all_label),model_path, level="debug") algo_log(f"训练参数: {np.shape(all_data)}, {np.shape(all_label)}, {model_path}", level="debug")
bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,model_path) bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,model_path)
algo_log('THE BEST ACCURACY IS ' + str(bestAcc), level="debug") algo_log(f"THE BEST ACCURACY IS {str(bestAcc)}", level="debug")
endtime = datetime.datetime.now() endtime = datetime.datetime.now()
algo_log('train duration: ',str(endtime - starttime), level="debug") algo_log(f"train duration: {endtime - starttime}", level="debug")
# 将模型或参数传回 # 将模型或参数传回
result_queue.put({ result_queue.put({
@@ -387,7 +384,7 @@ def offlineTrain(all_data,all_label,modelPath):
# seed_n = np.random.randint(2025) # seed_n = np.random.randint(2025)
seed_n = 1877 seed_n = 1877
algo_log('seed is ' + str(seed_n), level="debug") algo_log(f"seed is {seed_n}", level="debug")
random.seed(seed_n) random.seed(seed_n)
np.random.seed(seed_n) np.random.seed(seed_n)
torch.manual_seed(seed_n) torch.manual_seed(seed_n)
@@ -400,7 +397,7 @@ def offlineTrain(all_data,all_label,modelPath):
algo_log('THE BEST ACCURACY IS ' + str(bestAcc), level="debug") algo_log('THE BEST ACCURACY IS ' + str(bestAcc), level="debug")
endtime = datetime.datetime.now() endtime = datetime.datetime.now()
algo_log('train duration: ',str(endtime - starttime), level="debug") algo_log(f"train duration: {endtime - starttime}", level="debug")

View File

@@ -13,6 +13,13 @@ Debug_64ch_Decoder_Optimize is an updated version that fixes several issues and
6. decoder class切换问题 6. decoder class切换问题
7. decoder_class切换时数据重置、各类参数重置 7. decoder_class切换时数据重置、各类参数重置
# realease log
- 2026年6月11日11:29:17 打包第一版包名runDecoder.dist_v0.0.0_beta_20260611.7z
- 2026年6月11日12:00:00 打包第二版包名runDecoder.dist_v0.0.0_beta_20260611.7z
- 修复上位机先发decoder_class, 后发open_impedence 带来decoder_main thread 阻塞问题
- 2026年6月12日15:05:47 runDecoder.dist_v0.0.2_beta_20260612
- 优化filter读数精度
# 常用命令 # 常用命令
source activate 3in1Py310 source activate 3in1Py310
@@ -22,7 +29,15 @@ python ZeroMQClient_mock.py
python filter_test.py python filter_test.py
python upperHost_stimmock/MI_headless.py python upperHost_stimmock/MI_headless.py
# 打包命令
./nuitka_3in1_package.sh
# TODO
1. mvep是否要把list freq 开放到config
2. 滤波器参数 放到config文件
# 遗留问题 # debug log
1. mvep是否要把list freq 开放到config ## MI
Epoch采集完成|收到命令: {'method': 'train'|取出的
收到命令: {'method': 'train'|收到命令: {'method': 'train'|收到命令: {'method': 'predict'|事件检测到

View File

@@ -18,11 +18,11 @@ from logs.log import algo_log
class FbccaDw: class FbccaDw:
def __init__(self, fs, num_target, num_chans, num_filter, num_harms, stimTime, parameter, width, winNum,method): def __init__(self, fs, num_target, num_chans, num_filter, num_harms, stimTime, parameter, width, winNum,method):
algo_log('******************************************', level="debug") algo_log('******************************************', level="debug")
algo_log('parameter list', level="debug") algo_log('parameter list',level="debug")
algo_log('target:', num_target, level="debug") algo_log(f"target: {num_target}", level="debug")
algo_log('number of filter bank:', num_filter, level="debug") algo_log(f"number of filter bank: {num_filter}", level="debug")
algo_log('parameter:', parameter, level="debug") algo_log(f"parameter: {parameter}", level="debug")
algo_log('width:', width, level="debug") algo_log(f"width: {width}", level="debug")
self.phase = 0 self.phase = 0
self.bandWidth = width self.bandWidth = width
self.winNum = winNum self.winNum = winNum

View File

@@ -5,6 +5,7 @@
import numpy as np import numpy as np
import time import time
import threading import threading
import queue
from scipy import signal from scipy import signal
from logs.log import algo_log from logs.log import algo_log
import sys import sys
@@ -93,7 +94,74 @@ class FilterRingBuffer:
self.has_new_data = False # 重置时清空新数据标记 self.has_new_data = False # 重置时清空新数据标记
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# 2. 独立滑动滤波类(仅负责滤波业务逻辑,不关心缓存实现 # 2. 独立 Beta PSD 计算线程(避免阻塞滤波主循环的 200ms 定时
# -----------------------------------------------------------------------------
class BetaPsdCalculator(threading.Thread):
"""独立的 Beta PSD 计算线程,使用队列与滤波主线程解耦"""
def __init__(self, fs=250, window_size=750):
super().__init__(daemon=True)
self.fs = fs
self.window_size = window_size
self._beta_calc = Beta_Calculate(Threshold_value_low=0, Threshold_value_high=0, fs=fs)
self._input_queue = queue.Queue(maxsize=2)
self._running = threading.Event()
self._running.set()
self._latest_beta = None
self._beta_lock = threading.Lock()
self.beta_broadcast_callback = None
def push_data(self, data):
"""供外部调用的线程安全数据推送接口"""
try:
self._input_queue.put_nowait(data)
except queue.Full:
try:
self._input_queue.get_nowait()
except queue.Empty:
pass
try:
self._input_queue.put_nowait(data)
except queue.Full:
pass
def get_latest_beta(self):
"""获取最新的 beta 值(线程安全)"""
with self._beta_lock:
return self._latest_beta
def run(self):
while self._running.is_set():
try:
data = self._input_queue.get(timeout=1.5)
if data is None:
break
try:
beta_psd, _, _ = self._beta_calc.calculate_all(
data, fs=self.fs, nperseg=min(self.window_size, data.shape[1])
)
with self._beta_lock:
self._latest_beta = round(float(beta_psd), 3)
if self.beta_broadcast_callback is not None:
self.beta_broadcast_callback(self._latest_beta)
except Exception as e:
algo_log(f"Beta PSD 计算异常: {e}", level='error')
except queue.Empty:
pass
def stop(self):
"""停止计算线程"""
self._running.clear()
try:
self._input_queue.put_nowait(None)
except queue.Full:
pass
if self.is_alive():
self.join(timeout=2)
# -----------------------------------------------------------------------------
# 3. 独立滑动滤波类(仅负责滤波业务逻辑,不关心缓存实现)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class SlidingFilter(threading.Thread): class SlidingFilter(threading.Thread):
def __init__( def __init__(
@@ -121,23 +189,33 @@ class SlidingFilter(threading.Thread):
self.running.set() self.running.set()
# 滤波结果回调(外部可注册,获取滤波后的数据) # 滤波结果回调(外部可注册,获取滤波后的数据)
self.filter_result_callback = None self.filter_result_callback = None
# beta_psd 广播回调(外部注册,用于走 zmqServer 8099 端口发送)
self.beta_broadcast_callback = None
# beta 计算器Fp1/Fp2 通道,索引 0/1
self._beta_calc = Beta_Calculate(Threshold_value_low=0, Threshold_value_high=0, fs=srate)
# beta 每秒触发计数200ms步长5次 = 1s # beta 每秒触发计数200ms步长5次 = 1s
self._beta_step_counter = 0 self._beta_step_counter = 0
self._beta_steps_per_second = max(1, int(round(1.0 / step_sec))) # 5 self._beta_steps_per_second = max(1, int(round(1.0 / step_sec))) # 5
self.slide_window = None # 滑动窗口缓存 (n_chan, window_size)
self.slide_ready = False # 窗口是否已填满初始数据
# 预计算滤波器系数(仅执行一次) # 预计算滤波器系数(仅执行一次)
self._init_filters() self._init_filters()
# 独立的 Beta 计算线程(避免阻塞滤波主循环)
self._beta_thread = BetaPsdCalculator(fs=srate, window_size=self.window_size)
def start(self):
"""同时启动 Beta 计算线程和滤波主线程"""
self._beta_thread.start()
super().start()
def set_beta_broadcast_callback(self, callback):
"""注册 Beta PSD 广播回调函数"""
self._beta_thread.beta_broadcast_callback = callback
def _init_filters(self): def _init_filters(self):
"""预计算所有滤波器系数(仅执行一次)""" """预计算所有滤波器系数(仅执行一次)"""
# 50Hz工频陷波Q=30工业标准 # 50Hz工频陷波Q=30工业标准
self.b_notch, self.a_notch = signal.iirnotch(50, 30, self.srate) self.b_notch, self.a_notch = signal.iirnotch(50, 30, self.srate)
# 8~30Hz带通FIR65阶线性相位 # 0.5~45Hz带通FIR65阶线性相位
self.b_bp = signal.firwin( self.b_bp = signal.firwin(
numtaps=65, numtaps=65,
cutoff=[0.5/(self.srate/2), 45/(self.srate/2)], cutoff=[0.5/(self.srate/2), 45/(self.srate/2)],
@@ -162,50 +240,60 @@ class SlidingFilter(threading.Thread):
def run(self): def run(self):
"""线程主逻辑精确200ms触发一次滤波""" """线程主逻辑精确200ms触发一次滤波"""
interval = self.step_sec # 200ms = 0.2 interval = self.step_sec # 0.2s
next_run_time = time.perf_counter() # 以启动时刻为绝对时间基准(核心改动)
while self.running.is_set(): base_time = time.perf_counter()
# 1. 精确定时等待 frame_count = 0 # 帧计数器,用于对齐时序
current_time = time.perf_counter()
if current_time < next_run_time:
time.sleep(next_run_time - current_time)
next_run_time += interval
else:
algo_log("滤波耗时超过200ms定时偏移", level='debug')
next_run_time = time.perf_counter() + interval
# ========== 新增核心判断:无新数据则直接跳过 ========== while self.running.is_set():
# 计算理论执行时刻:严格按帧序号 × 步长
expect_time = base_time + frame_count * interval
current_time = time.perf_counter()
# 精确定时等待
if current_time < expect_time:
time.sleep(expect_time - current_time)
else:
# 处理超时:仅告警,不重置基准(防止累积偏移)
algo_log(f"滤波任务超时,偏移 {(current_time - expect_time)*1000:.1f} ms", level='debug')
frame_count += 1 # 帧序号自增,保证周期绝对稳定
if not self.ring_buffer.check_and_clear_new_data(): if not self.ring_buffer.check_and_clear_new_data():
# 无新数据,不执行滤波、不发送数据 # 无新数据,不执行滤波、不发送数据
continue continue
# 2. 有新数据,才执行原有滤波逻辑 # ========== 原有滤波逻辑 ==========
try: try:
window_data = self.ring_buffer.get_latest_n_points(self.window_size) if not self.slide_ready:
if window_data is None: # 阶段1首次填满3s初始窗口
algo_log(f"缓存数据不足,当前缓存{self.ring_buffer.GetDataLenCount()}点,需{self.window_size}", level='debug') full_data = self.ring_buffer.get_latest_n_points(self.window_size)
continue if full_data is None:
algo_log("初始窗口数据不足", level='debug')
continue
self.slide_window = full_data
self.slide_ready = True
else:
# 阶段2正常滑动 → 取最新50个新点增量拼接
new_step_data = self.ring_buffer.get_latest_n_points(self.step_size)
if new_step_data is None:
algo_log("滑动步长数据不足", level='debug')
continue
# 增量滑动丢弃前50点拼接新50点标准滑动窗口
self.slide_window = np.hstack([
self.slide_window[:, self.step_size:],
new_step_data
])
filtered_data, filtered_full = self._filter_window_data(window_data) filtered_data, filtered_full = self._filter_window_data(self.slide_window[:64, :])
# algo_log(f"滤波后{filtered_data.shape}数据", level='debug')
# ========== beta_psd 每秒计算一次Fp1/Fp2通道索引 0/1========== # Beta PSD 每秒计算一次
self._beta_step_counter += 1 self._beta_step_counter += 1
if self._beta_step_counter >= self._beta_steps_per_second: if self._beta_step_counter >= self._beta_steps_per_second:
self._beta_step_counter = 0 self._beta_step_counter = 0
try: self._beta_thread.push_data(filtered_full[:2, :])
# 直接使用已滤波的完整3s数据的前两通道Fp1/Fp2
filter_betadata = filtered_full[:2, :] # shape (2, 750)
beta_psd, _, _ = self._beta_calc.calculate_all(
filter_betadata, fs=self.srate, nperseg=min(self.window_size, filter_betadata.shape[1])
)
if self.beta_broadcast_callback is not None:
self.beta_broadcast_callback(round(float(beta_psd), 3))
except Exception as be:
algo_log(f"beta_psd计算异常: {be}", level='error')
if self.filter_result_callback is not None: if self.filter_result_callback is not None:
self.filter_result_callback(filtered_data[:64, :]) self.filter_result_callback(filtered_data)
except Exception as e: except Exception as e:
algo_log(f"滤波执行异常: {e}", level='error') algo_log(f"滤波执行异常: {e}", level='error')
@@ -214,17 +302,11 @@ class SlidingFilter(threading.Thread):
self.filter_result_callback = callback self.filter_result_callback = callback
def stop(self): def stop(self):
"""停止滤波线程(安全版)""" """停止滤波线程和 Beta 计算线程"""
# 1. 先设置停止标志Event.clear()是线程安全的) self._beta_thread.stop()
self.running.clear() self.running.clear()
# 2. 核心修复只有线程已启动且正在运行时才调用join
if self.is_alive(): if self.is_alive():
# 等待线程正常退出最多1秒
self.join(timeout=1) self.join(timeout=1)
# 超时未退出时打印警告,便于排查问题
if self.is_alive(): if self.is_alive():
algo_log("警告滤波线程在1秒内未正常退出可能存在阻塞操作", level="WARNING") algo_log("警告滤波线程在1秒内未正常退出可能存在阻塞操作", level="WARNING")
# 3. 无论线程是否启动,都打印停止日志
algo_log("滤波线程已停止") algo_log("滤波线程已停止")

View File

@@ -152,7 +152,8 @@ class zmqServer(threading.Thread):
msg = {'method': method, 'params': params} msg = {'method': method, 'params': params}
msg_bytes = json.dumps(msg).encode('utf-8') msg_bytes = json.dumps(msg).encode('utf-8')
algo_log(f"发送命令结果: {msg}", level="DEBUG") if msg['method'] != 'beta_psd':
algo_log(f"发送命令结果: {msg}", level="DEBUG")
# 广播到所有命令客户端 # 广播到所有命令客户端
for client_id in list(self.cmd_clients): for client_id in list(self.cmd_clients):
@@ -179,7 +180,7 @@ class zmqServer(threading.Thread):
# 转置为上位机需要的[50, 通道数]格式 # 转置为上位机需要的[50, 通道数]格式
filtered_data = filtered_data.T.astype(np.float64) filtered_data = filtered_data.T.astype(np.float64)
send_buf = filtered_data.tobytes() send_buf = filtered_data.tobytes()
algo_log(f"发送滤波数据,长度: {len(send_buf)}字节, filtered_data.shape: {filtered_data.shape}", level="DEBUG", record_once=True) # algo_log(f"发送滤波数据,长度: {len(send_buf)}字节, filtered_data.shape: {filtered_data.shape}", level="DEBUG", record_once=True)
self.data_send_queue.put(send_buf) self.data_send_queue.put(send_buf)
def _process_data_send_queue(self): def _process_data_send_queue(self):
@@ -225,6 +226,9 @@ class zmqServer(threading.Thread):
algo_log(f"无效JSON命令: {message_bytes.hex()}", level="ERROR") algo_log(f"无效JSON命令: {message_bytes.hex()}", level="ERROR")
self.broadcast_message("error", {"code": 400, "message": "无效JSON格式"}) self.broadcast_message("error", {"code": 400, "message": "无效JSON格式"})
return return
except Exception as e:
algo_log(f"_handle_cmd_message exception: {e}", level="ERROR")
return
algo_log(f"收到命令: {message}", level="INFO") algo_log(f"收到命令: {message}", level="INFO")
method = message.get("method") method = message.get("method")
@@ -270,6 +274,22 @@ class zmqServer(threading.Thread):
elif params == 2: #停止解码 elif params == 2: #停止解码
self.IsExitApp = True self.IsExitApp = True
self.running = False self.running = False
resp = {
"method": "predict_response",
"params": {
"code": 200,
"message": "ok"
}
}
try:
resp_bytes = json.dumps(resp, ensure_ascii=False).encode("utf-8")
self.cmd_socket.send_multipart([ident, b"", resp_bytes])
algo_log(f"predict 命令已即时回复客户端 {ident}", level="DEBUG")
except Exception as e:
algo_log(f"predict 命令回复失败: {e}", level="ERROR")
return
elif method == "rest": elif method == "rest":
self.state_mode = 'rest' self.state_mode = 'rest'
elif method == "impedance": elif method == "impedance":
@@ -353,24 +373,24 @@ class zmqServer(threading.Thread):
def detect_event(self, samples): def detect_event(self, samples):
self.pack_contain_event = False self.pack_contain_event = False
# 第65通道为事件通道 # 第65通道为事件通道
event = int(samples[-2][0]) events = np.array(samples[-2], dtype=np.int32).tolist()
# for idx, event in enumerate(events): for idx, event in enumerate(events):
if event in self.events: if event in self.events:
new_key = "".join( new_key = "".join(
[ [
str(event), str(event),
datetime.datetime.now().strftime("%Y-%m-%d \ datetime.datetime.now().strftime("%Y-%m-%d \
-%H-%M-%S"), -%H-%M-%S"),
] ]
) )
self.currentLabel = event self.currentLabel = event
if event == self.predict_event: if event == self.predict_event:
self.count_events[new_key] = self.latency + 1 self.count_events[new_key] = self.latency + 1
else: else:
self.count_events[new_key] = self.train_latency + 1 self.count_events[new_key] = self.train_latency + 1
self.event_inner_idx = self.device_info['frame_points'] - 1 self.event_inner_idx = idx
# algo_log(f"事件检测到: {event},索引: {idx}", level="DEBUG") algo_log(f"事件检测到: {events},索引: {idx}", level="DEBUG")
self.pack_contain_event = True self.pack_contain_event = True
# 倒计时并清理过期事件 # 倒计时并清理过期事件
drop_items = [] drop_items = []
@@ -415,7 +435,8 @@ class zmqServer(threading.Thread):
break break
except Exception as e: except Exception as e:
algo_log(f"服务器主循环异常: {e}", level="ERROR") algo_log(f"服务器主循环异常: {str(e)}", level="ERROR")
return
finally: finally:
self.running = False self.running = False
# 优雅关闭所有资源 # 优雅关闭所有资源

View File

@@ -18,11 +18,23 @@ Upper_Port = 8088
Decoder_Host = 127.0.0.1 Decoder_Host = 127.0.0.1
Decoder_Port = 8099 Decoder_Port = 8099
Serial_port = COM44 Serial_port = COM44
algo_log_level = DEBUG
console_output = 1
save_train_data = 0 save_train_data = 0
zmqServer_host = 127.0.0.1 zmqServer_host = 127.0.0.1
[algo_log]
# ========== 文件日志配置 ==========
file_log_enable = true
file_log_level = DEBUG
log_path = exe
retention_days = 3
# ========== 控制台/黑框配置 ==========
console_enable = true
console_show_window = true
console_log_level = DEBUG
; 64 导设备配置 ; 64 导设备配置
[device_type_1] [device_type_1]
sample_rate = 250 sample_rate = 250

View File

@@ -56,7 +56,7 @@ EMPTY_FRAME = b""
# 仿真信号配置 # 仿真信号配置
TARGET_CHANNEL = 0 TARGET_CHANNEL = 0
SIGNAL_FREQ_LIST = [3, 13] SIGNAL_FREQ_LIST = [13]
SIGNAL_AMP = 1.8 SIGNAL_AMP = 1.8
NOISE_GAUSSIAN_AMP = 0.4 NOISE_GAUSSIAN_AMP = 0.4
NOISE_POWER50_AMP = 0.3 NOISE_POWER50_AMP = 0.3
@@ -128,8 +128,8 @@ def generate_eeg_packet(pkt_idx: int) -> np.ndarray:
sig = 0.0 sig = 0.0
for freq in SIGNAL_FREQ_LIST: for freq in SIGNAL_FREQ_LIST:
sig += SIGNAL_AMP * np.sin(2 * np.pi * freq * t_arr) sig += SIGNAL_AMP * np.sin(2 * np.pi * freq * t_arr)
sig += NOISE_POWER50_AMP * np.sin(2 * np.pi * 50 * t_arr) # sig += NOISE_POWER50_AMP * np.sin(2 * np.pi * 50 * t_arr)
sig += NOISE_GAUSSIAN_AMP * np.random.randn(n_point) # sig += NOISE_GAUSSIAN_AMP * np.random.randn(n_point)
data[:, ch] = sig data[:, ch] = sig
# 事件通道、保留通道 # 事件通道、保留通道

View File

@@ -1,114 +1,172 @@
import os import os
import sys
from pathlib import Path
from datetime import datetime, timedelta from datetime import datetime, timedelta
import logging import logging
from logging.handlers import RotatingFileHandler from logging.handlers import RotatingFileHandler
import inspect import inspect
try:
import win32gui
import win32con
WIN32_AVAILABLE = True
except ImportError:
WIN32_AVAILABLE = False
from PubLibrary.InifileHelper import IniRead from PubLibrary.InifileHelper import IniRead
# 全局配置 # ===================== 新增:获取 EXE 同级目录 =====================
console_output = IniRead('system', 'console_output', '1') def get_app_root():
log_level = IniRead('system', 'algo_log_level', 'INFO') """获取 runDecoder.exe 所在的真实根目录(兼容 onefile / standalone"""
if getattr(sys, 'frozen', False):
# Nuitka / PyInstaller 打包后走这里
app_path = sys.executable
else:
# 本地源码运行时,取当前脚本目录
app_path = os.path.abspath(__file__)
return os.path.dirname(app_path)
# 程序根目录exe 同级)
APP_ROOT = Path(get_app_root())
# 日志文件夹名exe 同级下 logs 目录
DEFAULT_LOG_DIR = APP_ROOT / "logs"
# ===================== 读取 [algo_log] 配置 =====================
# 文件日志
FILE_LOG_ENABLE = IniRead("algo_log", "file_log_enable", "true").lower() == "true"
FILE_LOG_LEVEL = IniRead("algo_log", "file_log_level", "DEBUG").upper()
# 优先级:配置文件 > 默认exe同级logs
CFG_LOG_PATH = IniRead("algo_log", "log_path", "").strip()
if CFG_LOG_PATH == "exe":
LOG_DIR = DEFAULT_LOG_DIR
else:
LOG_DIR = Path(CFG_LOG_PATH)
LOG_RETENTION_DAYS = int(IniRead("algo_log", "retention_days", 3))
# 控制台日志 + 黑框控制
CONSOLE_ENABLE = IniRead("algo_log", "console_enable", "true").lower() == "true"
CONSOLE_SHOW_WINDOW = IniRead("algo_log", "console_show_window", "true").lower() == "true"
CONSOLE_LOG_LEVEL = IniRead("algo_log", "console_log_level", "INFO").upper()
# ===================== 全局常量与缓存 =====================
log_once_cache = set() log_once_cache = set()
logger_cache = {} logger_cache = {}
LOG_RETENTION_DAYS = 3
LOG_DIR = './logs/'
LOG_FILE_PREFIX = 'algo_log_' LOG_FILE_PREFIX = 'algo_log_'
# 确保日志目录存在
LOG_DIR.mkdir(parents=True, exist_ok=True)
LOG_DIR_STR = str(LOG_DIR) + "\\"
# 日志格式:时间 - 日志器名 - 级别 - 文件名:行号 - 函数名 - 日志内容 # 日志格式
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s' LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
DATE_FORMAT = '%Y-%m-%d %H:%M:%S' DATE_FORMAT = '%Y-%m-%d %H:%M:%S'
# 日志级别映射
LEVEL_MAP = {
"DEBUG": logging.DEBUG,
"INFO": logging.INFO,
"WARNING": logging.WARNING,
"ERROR": logging.ERROR,
"FATAL": logging.FATAL
}
FILE_LOG_LEVEL_INT = LEVEL_MAP.get(FILE_LOG_LEVEL, logging.INFO)
CONSOLE_LOG_LEVEL_INT = LEVEL_MAP.get(CONSOLE_LOG_LEVEL, logging.INFO)
def clean_old_logs(): # ===================== Windows 控制台黑框显示/隐藏 =====================
"""清理超过指定天数的旧日志文件""" def control_console_window():
if not sys.platform.startswith("win") or not WIN32_AVAILABLE:
return
try: try:
if not os.path.exists(LOG_DIR): hwnd = win32gui.GetForegroundWindow()
if CONSOLE_SHOW_WINDOW:
win32gui.ShowWindow(hwnd, win32con.SW_SHOW)
else:
win32gui.ShowWindow(hwnd, win32con.SW_HIDE)
except Exception:
pass
control_console_window()
# ===================== 清理过期日志 =====================
def clean_old_logs():
try:
if not LOG_DIR.exists():
return return
expire_date = datetime.now() - timedelta(days=LOG_RETENTION_DAYS) expire_date = datetime.now() - timedelta(days=LOG_RETENTION_DAYS)
for filename in os.listdir(LOG_DIR): for filename in os.listdir(LOG_DIR):
if not filename.startswith(LOG_FILE_PREFIX) or not filename.endswith('.log'): if not (filename.startswith(LOG_FILE_PREFIX) and filename.endswith('.log')):
continue continue
date_str = filename[len(LOG_FILE_PREFIX):-4] date_str = filename[len(LOG_FILE_PREFIX):-4]
try: try:
file_date = datetime.strptime(date_str, '%Y-%m-%d') file_date = datetime.strptime(date_str, '%Y-%m-%d')
if file_date < expire_date: if file_date < expire_date:
file_path = os.path.join(LOG_DIR, filename) file_path = LOG_DIR / filename
os.remove(file_path) os.remove(file_path)
print(f"清理过期日志: {file_path}")
except ValueError: except ValueError:
continue continue
except Exception as e: except Exception:
print(f"清理旧日志异常: {str(e)}") pass
# ===================== 初始化日志器 =====================
def init_module_logger(logger_name): def init_module_logger(logger_name):
"""初始化日志器 + 清理旧日志"""
os.makedirs(LOG_DIR, exist_ok=True)
clean_old_logs()
current_date = datetime.now().strftime("%Y-%m-%d")
log_file = os.path.join(LOG_DIR, f"{LOG_FILE_PREFIX}{current_date}.log")
if logger_name in logger_cache: if logger_name in logger_cache:
return logger_cache[logger_name] return logger_cache[logger_name]
clean_old_logs()
logger = logging.getLogger(logger_name) logger = logging.getLogger(logger_name)
logger.setLevel(log_level) logger.setLevel(logging.DEBUG)
if logger.handlers: if logger.handlers:
logger_cache[logger_name] = logger logger_cache[logger_name] = logger
return logger return logger
# 文件输出处理器
file_handler = RotatingFileHandler(
log_file,
maxBytes=10 * 1024 * 1024,
backupCount=10,
encoding='utf-8'
)
formatter = logging.Formatter(LOG_FORMAT, datefmt=DATE_FORMAT) formatter = logging.Formatter(LOG_FORMAT, datefmt=DATE_FORMAT)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
# 控制台输出 # 文件日志
if console_output: if FILE_LOG_ENABLE:
console_handler = logging.StreamHandler() current_date = datetime.now().strftime("%Y-%m-%d")
log_file = LOG_DIR / f"{LOG_FILE_PREFIX}{current_date}.log"
file_handler = RotatingFileHandler(
log_file,
maxBytes=10 * 1024 * 1024,
backupCount=10,
encoding='utf-8'
)
file_handler.setFormatter(formatter)
file_handler.setLevel(FILE_LOG_LEVEL_INT)
logger.addHandler(file_handler)
# 控制台日志
if CONSOLE_ENABLE:
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(formatter) console_handler.setFormatter(formatter)
console_handler.setLevel(CONSOLE_LOG_LEVEL_INT)
logger.addHandler(console_handler) logger.addHandler(console_handler)
logger_cache[logger_name] = logger logger_cache[logger_name] = logger
return logger return logger
# ===================== 对外日志入口函数 =====================
def algo_log(content, level="INFO", record_once=False): def algo_log(content, level="INFO", record_once=False):
""" frame = inspect.currentframe()
日志入口函数 if frame:
自动记录:调用文件名、代码行号、所在函数 frame = frame.f_back.f_back
""" file_name = os.path.basename(frame.f_code.co_filename) if frame else "unknown"
# 回溯栈帧,获取真正调用 algo_log 的代码位置
# f_back(1) -> algo_log 自身f_back(2) -> 业务调用处
frame = inspect.currentframe().f_back.f_back
if not frame:
file_name = "unknown"
else:
file_name = os.path.basename(frame.f_code.co_filename)
logger = init_module_logger(file_name) logger = init_module_logger(file_name)
# 单次日志去重
if record_once: if record_once:
log_key = f"{level.upper()}_{content}" log_key = f"{level.upper()}_{content}"
if log_key in log_once_cache: if log_key in log_once_cache:
return return
log_once_cache.add(log_key) log_once_cache.add(log_key)
# 日志级别分发
level_upper = level.upper() level_upper = level.upper()
log_map = { log_func_map = {
"DEBUG": logger.debug, "DEBUG": logger.debug,
"INFO": logger.info,
"WARNING": logger.warning, "WARNING": logger.warning,
"ERROR": logger.error, "ERROR": logger.error,
"FATAL": logger.fatal, "FATAL": logger.fatal
"INFO": logger.info
} }
log_func = log_map.get(level_upper, logger.info) log_func = log_func_map.get(level_upper, logger.info)
log_func(content) log_func(content)

View File

@@ -28,7 +28,6 @@ echo "输出目录:${OUT_DIR}"
python -m nuitka \ python -m nuitka \
--standalone \ --standalone \
--msvc=latest \ --msvc=latest \
--windows-console-mode=disable \
--module-parameter=torch-disable-jit=yes \ --module-parameter=torch-disable-jit=yes \
--enable-plugin=no-qt \ --enable-plugin=no-qt \
--include-package=numpy \ --include-package=numpy \

View File

@@ -1,7 +1,7 @@
import matplotlib # import matplotlib
matplotlib.use('Agg') # matplotlib.use('Agg')
import argparse # import argparse
import sys # import sys
import time import time
from Decoder import Decoder_main from Decoder import Decoder_main
from PubLibrary.RunOnce import is_program_running from PubLibrary.RunOnce import is_program_running