Compare commits

..

17 Commits

Author SHA1 Message Date
c27e250fad update log config 2026-06-13 19:47:27 +08:00
66c0b71b89 update ip 2026-06-13 17:35:46 +08:00
5c7b73b7a4 add log 2026-06-13 16:49:29 +08:00
9690971f43 update 2026-06-13 11:54:58 +08:00
5a5f103ef6 update log 2026-06-13 10:06:29 +08:00
b31bb18dfe update 2026-06-12 15:21:47 +08:00
38480a2ca3 remove release 2026-06-12 14:30:11 +08:00
62e7cab5be add sleep if open impedence 2026-06-12 13:56:48 +08:00
Ivey Song
b26ae2ce3c beta psd 独立线程 2026-06-12 11:33:48 +08:00
5488626112 update 2026-06-11 14:29:43 +08:00
d59b0f695f realeas v1 2026-06-11 11:55:35 +08:00
0570d41439 bug fix 2026-06-11 11:06:59 +08:00
4574798d86 release v1 2026-06-11 09:21:57 +08:00
d480107b37 update 2026-06-11 08:11:29 +08:00
2d70fc9956 update log path 2026-06-11 08:04:08 +08:00
Ivey Song
1bbe84eb56 beta psd return 2026-06-10 17:55:43 +08:00
Ivey Song
f21367bc20 betapsd 回调 2026-06-10 17:55:43 +08:00
14 changed files with 371 additions and 148 deletions

1
.gitignore vendored
View File

@@ -2,6 +2,7 @@
__pycache__/
# Distribution / packaging
release/
build/
dist/
dist_nuitka/

View File

@@ -62,6 +62,8 @@ class Decoder_main(threading.Thread):
# 注册滤波结果回调(示例:打印数据形状)
self.sliding_filter.filter_result_callback = self.zmqServer.send_filtered_data
# 注册 beta_psd 广播回调,每秒通过 8099 端口发送给上位机
self.sliding_filter.set_beta_broadcast_callback(lambda v: self.zmqServer.broadcast_message('beta_psd', v))
def is_valid_signal(self, data, threshold=1e5): # 判断当前信号是否为有效信号
# data: (chans, samples)
@@ -155,8 +157,8 @@ class Decoder_main(threading.Thread):
# self.blink_b, self.blink_a = signal.butter(4, [self.l_freq / (self.device_info['sample_rate'] / 2), self.h_freq / (self.device_info['sample_rate'] / 2)], btype='band')
def parameter_init(self,bandPass_low,bandPass_high):
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in self.interval_epoch] # epoch截取信息
self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])] # 训练样本epoch
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in self.interval_epoch] # epoch截取信息 ssmvep [50, 550]
self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])] # 训练样本epoch ssmevep [50, 575]
self.trainData = [] #训练数据
self.trainLabel = [] #训练标签
self.plotData = [] #报告分析数据
@@ -204,6 +206,9 @@ class Decoder_main(threading.Thread):
self.zmqServer.state_mode = 'rest'
try:
if self.zmqServer.open_Impedance:
time.sleep(0.005)
continue
if self.decoder_class == 'ssvep' or self.decoder_class == 'pvs':
self.decoder_SSVEP()
elif self.decoder_class == 'ssmvep':
@@ -213,7 +218,7 @@ class Decoder_main(threading.Thread):
else:
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
time.sleep(0.005)
continue;
continue
self.zmqServer.paradigmBuffer.getData(25)
except Exception as e:
algo_log(f"Decoder Loop Error: {e}")
@@ -231,7 +236,7 @@ class Decoder_main(threading.Thread):
if self.zmqServer.open_Impedance: # 阻抗检测状态不解码
return
data = self.zmqServer.paradigmBuffer.getDataViaSSVEP(50)
# algo_log(f"SSVEP取出的{data.shape}, data = {data[:20]}", level="DEBUG")
# algo_log(f"SSVEP取出的{data.shape}, data = {data[:, :10]}", level="DEBUG")
data = data[:self.n_chan, :]
if self.decodingSteps == 1 and hasattr(self,'dw'): # 开始预热
self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时
@@ -284,6 +289,7 @@ class Decoder_main(threading.Thread):
and self.trainLabel.count(self.currentLabel) < self.single_train:
self.trainData.append(trainTrial)
self.trainLabel.append(self.currentLabel)
algo_log(f"SSMVEP训练集{np.shape(self.trainData)}", level="DEBUG")
else:
time.sleep(0.0001)
return
@@ -423,7 +429,7 @@ class Decoder_main(threading.Thread):
y_pred = torch.max(Cls, 1)[1]
self.plotLabel.append(int(y_pred.item()))
algo_log(f"MI运动意图识别: {y_pred}")
self.zmqServer.broadcast_message('paradigm', int(y_pred.item()))
self.zmqServer.broadcast_message('result', int(y_pred.item()))
end = time.time()
algo_log(f'MI发送给界面完成,耗时{end - start:.3f}s。')
else: # 休息状态

View File

@@ -318,11 +318,7 @@ class ExP():
train_pred = torch.max(outputs, 1)[1]
train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))
algo_log('Epoch:', e,
' Train loss: %.6f' % loss.detach().cpu().numpy(),
' Test loss: %.6f' % loss_test.detach().cpu().numpy(),
' Train accuracy %.6f' % train_acc,
' Test accuracy is %.6f' % acc, level="debug")
algo_log(f"Epoch = {e}, Train loss = {loss.detach().cpu().numpy():.6f}, Test loss = {loss_test.detach().cpu().numpy():.6f}, Train accuracy = {train_acc:.6f}, Test accuracy = {acc:.6f}", level="debug")
self.log_write.write(str(e) + " " + str(acc) + "\n")
num = num + 1
@@ -335,8 +331,8 @@ class ExP():
torch.save(self.model, model_path)
averAcc = averAcc / num
algo_log('The average accuracy is:', averAcc, level="debug")
algo_log('The best accuracy is:', bestAcc, level="debug")
algo_log(f"The average accuracy is: {averAcc}", level="debug")
algo_log(f"The best accuracy is: {bestAcc}", level="debug")
self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n")
self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n")
@@ -366,12 +362,13 @@ def onlineTrain(data_queue,result_queue):
data = data_queue.get(timeout=30)
all_data, all_label,model_path,n_chan = data['data'], data['label'],data['modelPath'],data['n_chan']
exp = ExP(n_chan)
algo_log('训练参数: ',np.shape(all_data),np.shape(all_label),model_path, level="debug")
algo_log(f"训练参数: {np.shape(all_data)}, {np.shape(all_label)}, {model_path}", level="debug")
bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,model_path)
algo_log('THE BEST ACCURACY IS ' + str(bestAcc), level="debug")
algo_log(f"THE BEST ACCURACY IS {str(bestAcc)}", level="debug")
endtime = datetime.datetime.now()
algo_log('train duration: ',str(endtime - starttime), level="debug")
algo_log(f"train duration: {endtime - starttime}", level="debug")
# 将模型或参数传回
result_queue.put({
@@ -387,7 +384,7 @@ def offlineTrain(all_data,all_label,modelPath):
# seed_n = np.random.randint(2025)
seed_n = 1877
algo_log('seed is ' + str(seed_n), level="debug")
algo_log(f"seed is {seed_n}", level="debug")
random.seed(seed_n)
np.random.seed(seed_n)
torch.manual_seed(seed_n)
@@ -400,7 +397,7 @@ def offlineTrain(all_data,all_label,modelPath):
algo_log('THE BEST ACCURACY IS ' + str(bestAcc), level="debug")
endtime = datetime.datetime.now()
algo_log('train duration: ',str(endtime - starttime), level="debug")
algo_log(f"train duration: {endtime - starttime}", level="debug")

View File

@@ -13,6 +13,13 @@ Debug_64ch_Decoder_Optimize is an updated version that fixes several issues and
6. decoder class切换问题
7. decoder_class切换时数据重置、各类参数重置
# realease log
- 2026年6月11日11:29:17 打包第一版包名runDecoder.dist_v0.0.0_beta_20260611.7z
- 2026年6月11日12:00:00 打包第二版包名runDecoder.dist_v0.0.0_beta_20260611.7z
- 修复上位机先发decoder_class, 后发open_impedence 带来decoder_main thread 阻塞问题
- 2026年6月12日15:05:47 runDecoder.dist_v0.0.2_beta_20260612
- 优化filter读数精度
# 常用命令
source activate 3in1Py310
@@ -22,7 +29,15 @@ python ZeroMQClient_mock.py
python filter_test.py
python upperHost_stimmock/MI_headless.py
# 打包命令
./nuitka_3in1_package.sh
# 遗留问题
# TODO
1. mvep是否要把list freq 开放到config
2. 滤波器参数 放到config文件
# debug log
## MI
Epoch采集完成|收到命令: {'method': 'train'|取出的
收到命令: {'method': 'train'|收到命令: {'method': 'train'|收到命令: {'method': 'predict'|事件检测到

View File

@@ -19,10 +19,10 @@ class FbccaDw:
def __init__(self, fs, num_target, num_chans, num_filter, num_harms, stimTime, parameter, width, winNum,method):
algo_log('******************************************', level="debug")
algo_log('parameter list',level="debug")
algo_log('target:', num_target, level="debug")
algo_log('number of filter bank:', num_filter, level="debug")
algo_log('parameter:', parameter, level="debug")
algo_log('width:', width, level="debug")
algo_log(f"target: {num_target}", level="debug")
algo_log(f"number of filter bank: {num_filter}", level="debug")
algo_log(f"parameter: {parameter}", level="debug")
algo_log(f"width: {width}", level="debug")
self.phase = 0
self.bandWidth = width
self.winNum = winNum

View File

@@ -20,7 +20,7 @@ class Beta_Calculate():
alpha_psd = np.sum(self.band_psd(freqs, psd, (8, 13)))
theta_psd = np.sum(self.band_psd(freqs, psd, (4, 8)))
print(f"[功率] β={beta_psd:.2f} | α={alpha_psd:.2f} | θ={theta_psd:.2f}")
# print(f"[功率] β={beta_psd:.2f} | α={alpha_psd:.2f} | θ={theta_psd:.2f}")
return beta_psd, alpha_psd, theta_psd

View File

@@ -5,8 +5,13 @@
import numpy as np
import time
import threading
import queue
from scipy import signal
from logs.log import algo_log
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from Tools.beta_calculate import Beta_Calculate
class FilterRingBuffer:
def __init__(self, n_chan, n_points):
@@ -89,7 +94,74 @@ class FilterRingBuffer:
self.has_new_data = False # 重置时清空新数据标记
# -----------------------------------------------------------------------------
# 2. 独立滑动滤波类(仅负责滤波业务逻辑,不关心缓存实现
# 2. 独立 Beta PSD 计算线程(避免阻塞滤波主循环的 200ms 定时
# -----------------------------------------------------------------------------
class BetaPsdCalculator(threading.Thread):
"""独立的 Beta PSD 计算线程,使用队列与滤波主线程解耦"""
def __init__(self, fs=250, window_size=750):
super().__init__(daemon=True)
self.fs = fs
self.window_size = window_size
self._beta_calc = Beta_Calculate(Threshold_value_low=0, Threshold_value_high=0, fs=fs)
self._input_queue = queue.Queue(maxsize=2)
self._running = threading.Event()
self._running.set()
self._latest_beta = None
self._beta_lock = threading.Lock()
self.beta_broadcast_callback = None
def push_data(self, data):
"""供外部调用的线程安全数据推送接口"""
try:
self._input_queue.put_nowait(data)
except queue.Full:
try:
self._input_queue.get_nowait()
except queue.Empty:
pass
try:
self._input_queue.put_nowait(data)
except queue.Full:
pass
def get_latest_beta(self):
"""获取最新的 beta 值(线程安全)"""
with self._beta_lock:
return self._latest_beta
def run(self):
while self._running.is_set():
try:
data = self._input_queue.get(timeout=1.5)
if data is None:
break
try:
beta_psd, _, _ = self._beta_calc.calculate_all(
data, fs=self.fs, nperseg=min(self.window_size, data.shape[1])
)
with self._beta_lock:
self._latest_beta = round(float(beta_psd), 3)
if self.beta_broadcast_callback is not None:
self.beta_broadcast_callback(self._latest_beta)
except Exception as e:
algo_log(f"Beta PSD 计算异常: {e}", level='error')
except queue.Empty:
pass
def stop(self):
"""停止计算线程"""
self._running.clear()
try:
self._input_queue.put_nowait(None)
except queue.Full:
pass
if self.is_alive():
self.join(timeout=2)
# -----------------------------------------------------------------------------
# 3. 独立滑动滤波类(仅负责滤波业务逻辑,不关心缓存实现)
# -----------------------------------------------------------------------------
class SlidingFilter(threading.Thread):
def __init__(
@@ -118,14 +190,32 @@ class SlidingFilter(threading.Thread):
# 滤波结果回调(外部可注册,获取滤波后的数据)
self.filter_result_callback = None
# beta 每秒触发计数200ms步长5次 = 1s
self._beta_step_counter = 0
self._beta_steps_per_second = max(1, int(round(1.0 / step_sec))) # 5
self.slide_window = None # 滑动窗口缓存 (n_chan, window_size)
self.slide_ready = False # 窗口是否已填满初始数据
# 预计算滤波器系数(仅执行一次)
self._init_filters()
# 独立的 Beta 计算线程(避免阻塞滤波主循环)
self._beta_thread = BetaPsdCalculator(fs=srate, window_size=self.window_size)
def start(self):
"""同时启动 Beta 计算线程和滤波主线程"""
self._beta_thread.start()
super().start()
def set_beta_broadcast_callback(self, callback):
"""注册 Beta PSD 广播回调函数"""
self._beta_thread.beta_broadcast_callback = callback
def _init_filters(self):
"""预计算所有滤波器系数(仅执行一次)"""
# 50Hz工频陷波Q=30工业标准
self.b_notch, self.a_notch = signal.iirnotch(50, 30, self.srate)
# 8~30Hz带通FIR65阶线性相位
# 0.5~45Hz带通FIR65阶线性相位
self.b_bp = signal.firwin(
numtaps=65,
cutoff=[0.5/(self.srate/2), 45/(self.srate/2)],
@@ -135,7 +225,7 @@ class SlidingFilter(threading.Thread):
self.a_bp = np.array([1.0])
def _filter_window_data(self, window_data):
"""对3秒窗口数据执行滤波返回无边界效应的200ms数据"""
"""对3秒窗口数据执行滤波返回 (无边界效应的200ms数据, 完整3s滤波数据)"""
# 零相位滤波(无延迟,无边界效应)
filtered = window_data - np.mean(window_data, axis=-1, keepdims=True)
filtered = signal.filtfilt(self.b_notch, self.a_notch, filtered, axis=-1)
@@ -146,39 +236,64 @@ class SlidingFilter(threading.Thread):
start_idx = self.window_size - 2 * self.step_size
end_idx = self.window_size - self.step_size
output_data = filtered[:, start_idx:end_idx].copy()
return output_data
return output_data, filtered
def run(self):
"""线程主逻辑精确200ms触发一次滤波"""
interval = self.step_sec # 200ms = 0.2
next_run_time = time.perf_counter()
while self.running.is_set():
# 1. 精确定时等待
current_time = time.perf_counter()
if current_time < next_run_time:
time.sleep(next_run_time - current_time)
next_run_time += interval
else:
algo_log("滤波耗时超过200ms定时偏移", level='debug')
next_run_time = time.perf_counter() + interval
interval = self.step_sec # 0.2s
# 以启动时刻为绝对时间基准(核心改动)
base_time = time.perf_counter()
frame_count = 0 # 帧计数器,用于对齐时序
# ========== 新增核心判断:无新数据则直接跳过 ==========
while self.running.is_set():
# 计算理论执行时刻:严格按帧序号 × 步长
expect_time = base_time + frame_count * interval
current_time = time.perf_counter()
# 精确定时等待
if current_time < expect_time:
time.sleep(expect_time - current_time)
else:
# 处理超时:仅告警,不重置基准(防止累积偏移)
algo_log(f"滤波任务超时,偏移 {(current_time - expect_time)*1000:.1f} ms", level='debug')
frame_count += 1 # 帧序号自增,保证周期绝对稳定
if not self.ring_buffer.check_and_clear_new_data():
# 无新数据,不执行滤波、不发送数据
continue
# 2. 有新数据,才执行原有滤波逻辑
# ========== 原有滤波逻辑 ==========
try:
window_data = self.ring_buffer.get_latest_n_points(self.window_size)
if window_data is None:
algo_log(f"缓存数据不足,当前缓存{self.ring_buffer.GetDataLenCount()}点,需{self.window_size}", level='debug')
if not self.slide_ready:
# 阶段1首次填满3s初始窗口
full_data = self.ring_buffer.get_latest_n_points(self.window_size)
if full_data is None:
algo_log("初始窗口数据不足", level='debug')
continue
self.slide_window = full_data
self.slide_ready = True
else:
# 阶段2正常滑动 → 取最新50个新点增量拼接
new_step_data = self.ring_buffer.get_latest_n_points(self.step_size)
if new_step_data is None:
algo_log("滑动步长数据不足", level='debug')
continue
# 增量滑动丢弃前50点拼接新50点标准滑动窗口
self.slide_window = np.hstack([
self.slide_window[:, self.step_size:],
new_step_data
])
filtered_data = self._filter_window_data(window_data)
# algo_log(f"滤波后{filtered_data.shape}数据", level='debug')
filtered_data, filtered_full = self._filter_window_data(self.slide_window[:64, :])
# Beta PSD 每秒计算一次
self._beta_step_counter += 1
if self._beta_step_counter >= self._beta_steps_per_second:
self._beta_step_counter = 0
self._beta_thread.push_data(filtered_full[:2, :])
if self.filter_result_callback is not None:
self.filter_result_callback(filtered_data[:64, :])
self.filter_result_callback(filtered_data)
except Exception as e:
algo_log(f"滤波执行异常: {e}", level='error')
@@ -187,17 +302,11 @@ class SlidingFilter(threading.Thread):
self.filter_result_callback = callback
def stop(self):
"""停止滤波线程(安全版)"""
# 1. 先设置停止标志Event.clear()是线程安全的)
"""停止滤波线程和 Beta 计算线程"""
self._beta_thread.stop()
self.running.clear()
# 2. 核心修复只有线程已启动且正在运行时才调用join
if self.is_alive():
# 等待线程正常退出最多1秒
self.join(timeout=1)
# 超时未退出时打印警告,便于排查问题
if self.is_alive():
algo_log("警告滤波线程在1秒内未正常退出可能存在阻塞操作", level="WARNING")
# 3. 无论线程是否启动,都打印停止日志
algo_log("滤波线程已停止")

View File

@@ -152,6 +152,7 @@ class zmqServer(threading.Thread):
msg = {'method': method, 'params': params}
msg_bytes = json.dumps(msg).encode('utf-8')
if msg['method'] != 'beta_psd':
algo_log(f"发送命令结果: {msg}", level="DEBUG")
# 广播到所有命令客户端
@@ -179,7 +180,7 @@ class zmqServer(threading.Thread):
# 转置为上位机需要的[50, 通道数]格式
filtered_data = filtered_data.T.astype(np.float64)
send_buf = filtered_data.tobytes()
algo_log(f"发送滤波数据,长度: {len(send_buf)}字节, filtered_data.shape: {filtered_data.shape}", level="DEBUG", record_once=True)
# algo_log(f"发送滤波数据,长度: {len(send_buf)}字节, filtered_data.shape: {filtered_data.shape}", level="DEBUG", record_once=True)
self.data_send_queue.put(send_buf)
def _process_data_send_queue(self):
@@ -225,6 +226,9 @@ class zmqServer(threading.Thread):
algo_log(f"无效JSON命令: {message_bytes.hex()}", level="ERROR")
self.broadcast_message("error", {"code": 400, "message": "无效JSON格式"})
return
except Exception as e:
algo_log(f"_handle_cmd_message exception: {e}", level="ERROR")
return
algo_log(f"收到命令: {message}", level="INFO")
method = message.get("method")
@@ -270,6 +274,22 @@ class zmqServer(threading.Thread):
elif params == 2: #停止解码
self.IsExitApp = True
self.running = False
resp = {
"method": "predict_response",
"params": {
"code": 200,
"message": "ok"
}
}
try:
resp_bytes = json.dumps(resp, ensure_ascii=False).encode("utf-8")
self.cmd_socket.send_multipart([ident, b"", resp_bytes])
algo_log(f"predict 命令已即时回复客户端 {ident}", level="DEBUG")
except Exception as e:
algo_log(f"predict 命令回复失败: {e}", level="ERROR")
return
elif method == "rest":
self.state_mode = 'rest'
elif method == "impedance":
@@ -353,8 +373,8 @@ class zmqServer(threading.Thread):
def detect_event(self, samples):
self.pack_contain_event = False
# 第65通道为事件通道
event = int(samples[-2][0])
# for idx, event in enumerate(events):
events = np.array(samples[-2], dtype=np.int32).tolist()
for idx, event in enumerate(events):
if event in self.events:
new_key = "".join(
[
@@ -368,8 +388,8 @@ class zmqServer(threading.Thread):
self.count_events[new_key] = self.latency + 1
else:
self.count_events[new_key] = self.train_latency + 1
self.event_inner_idx = self.device_info['frame_points'] - 1
# algo_log(f"事件检测到: {event},索引: {idx}", level="DEBUG")
self.event_inner_idx = idx
algo_log(f"事件检测到: {events},索引: {idx}", level="DEBUG")
self.pack_contain_event = True
# 倒计时并清理过期事件
@@ -405,13 +425,18 @@ class zmqServer(threading.Thread):
frames = self.cmd_socket.recv_multipart()
self._handle_cmd_message(frames)
# 处理8100数据端口消息
# 处理8100数据端口消息(排空积压,消除标签延迟)
if self.data_socket in socks and socks[self.data_socket] == zmq.POLLIN:
frames = self.data_socket.recv_multipart()
while True:
try:
frames = self.data_socket.recv_multipart(zmq.NOBLOCK)
self._handle_data_message(frames)
except zmq.Again:
break
except Exception as e:
algo_log(f"服务器主循环异常: {e}", level="ERROR")
algo_log(f"服务器主循环异常: {str(e)}", level="ERROR")
return
finally:
self.running = False
# 优雅关闭所有资源

View File

@@ -18,11 +18,23 @@ Upper_Port = 8088
Decoder_Host = 127.0.0.1
Decoder_Port = 8099
Serial_port = COM44
algo_log_level = DEBUG
console_output = 1
save_train_data = 0
zmqServer_host = 127.0.0.1
[algo_log]
# ========== 文件日志配置 ==========
file_log_enable = true
file_log_level = DEBUG
log_path = exe
retention_days = 3
# ========== 控制台/黑框配置 ==========
console_enable = true
console_show_window = true
console_log_level = DEBUG
; 64 导设备配置
[device_type_1]
sample_rate = 250

View File

@@ -56,7 +56,7 @@ EMPTY_FRAME = b""
# 仿真信号配置
TARGET_CHANNEL = 0
SIGNAL_FREQ_LIST = [3, 13]
SIGNAL_FREQ_LIST = [13]
SIGNAL_AMP = 1.8
NOISE_GAUSSIAN_AMP = 0.4
NOISE_POWER50_AMP = 0.3
@@ -128,8 +128,8 @@ def generate_eeg_packet(pkt_idx: int) -> np.ndarray:
sig = 0.0
for freq in SIGNAL_FREQ_LIST:
sig += SIGNAL_AMP * np.sin(2 * np.pi * freq * t_arr)
sig += NOISE_POWER50_AMP * np.sin(2 * np.pi * 50 * t_arr)
sig += NOISE_GAUSSIAN_AMP * np.random.randn(n_point)
# sig += NOISE_POWER50_AMP * np.sin(2 * np.pi * 50 * t_arr)
# sig += NOISE_GAUSSIAN_AMP * np.random.randn(n_point)
data[:, ch] = sig
# 事件通道、保留通道

View File

@@ -1,114 +1,172 @@
import os
import sys
from pathlib import Path
from datetime import datetime, timedelta
import logging
from logging.handlers import RotatingFileHandler
import inspect
try:
import win32gui
import win32con
WIN32_AVAILABLE = True
except ImportError:
WIN32_AVAILABLE = False
from PubLibrary.InifileHelper import IniRead
# 全局配置
console_output = IniRead('system', 'console_output', '1')
log_level = IniRead('system', 'algo_log_level', 'INFO')
# ===================== 新增:获取 EXE 同级目录 =====================
def get_app_root():
"""获取 runDecoder.exe 所在的真实根目录(兼容 onefile / standalone"""
if getattr(sys, 'frozen', False):
# Nuitka / PyInstaller 打包后走这里
app_path = sys.executable
else:
# 本地源码运行时,取当前脚本目录
app_path = os.path.abspath(__file__)
return os.path.dirname(app_path)
# 程序根目录exe 同级)
APP_ROOT = Path(get_app_root())
# 日志文件夹名exe 同级下 logs 目录
DEFAULT_LOG_DIR = APP_ROOT / "logs"
# ===================== 读取 [algo_log] 配置 =====================
# 文件日志
FILE_LOG_ENABLE = IniRead("algo_log", "file_log_enable", "true").lower() == "true"
FILE_LOG_LEVEL = IniRead("algo_log", "file_log_level", "DEBUG").upper()
# 优先级:配置文件 > 默认exe同级logs
CFG_LOG_PATH = IniRead("algo_log", "log_path", "").strip()
if CFG_LOG_PATH == "exe":
LOG_DIR = DEFAULT_LOG_DIR
else:
LOG_DIR = Path(CFG_LOG_PATH)
LOG_RETENTION_DAYS = int(IniRead("algo_log", "retention_days", 3))
# 控制台日志 + 黑框控制
CONSOLE_ENABLE = IniRead("algo_log", "console_enable", "true").lower() == "true"
CONSOLE_SHOW_WINDOW = IniRead("algo_log", "console_show_window", "true").lower() == "true"
CONSOLE_LOG_LEVEL = IniRead("algo_log", "console_log_level", "INFO").upper()
# ===================== 全局常量与缓存 =====================
log_once_cache = set()
logger_cache = {}
LOG_RETENTION_DAYS = 3
LOG_DIR = './logs/'
LOG_FILE_PREFIX = 'algo_log_'
# 确保日志目录存在
LOG_DIR.mkdir(parents=True, exist_ok=True)
LOG_DIR_STR = str(LOG_DIR) + "\\"
# 日志格式:时间 - 日志器名 - 级别 - 文件名:行号 - 函数名 - 日志内容
# 日志格式
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
DATE_FORMAT = '%Y-%m-%d %H:%M:%S'
# 日志级别映射
LEVEL_MAP = {
"DEBUG": logging.DEBUG,
"INFO": logging.INFO,
"WARNING": logging.WARNING,
"ERROR": logging.ERROR,
"FATAL": logging.FATAL
}
FILE_LOG_LEVEL_INT = LEVEL_MAP.get(FILE_LOG_LEVEL, logging.INFO)
CONSOLE_LOG_LEVEL_INT = LEVEL_MAP.get(CONSOLE_LOG_LEVEL, logging.INFO)
def clean_old_logs():
"""清理超过指定天数的旧日志文件"""
# ===================== Windows 控制台黑框显示/隐藏 =====================
def control_console_window():
if not sys.platform.startswith("win") or not WIN32_AVAILABLE:
return
try:
if not os.path.exists(LOG_DIR):
hwnd = win32gui.GetForegroundWindow()
if CONSOLE_SHOW_WINDOW:
win32gui.ShowWindow(hwnd, win32con.SW_SHOW)
else:
win32gui.ShowWindow(hwnd, win32con.SW_HIDE)
except Exception:
pass
control_console_window()
# ===================== 清理过期日志 =====================
def clean_old_logs():
try:
if not LOG_DIR.exists():
return
expire_date = datetime.now() - timedelta(days=LOG_RETENTION_DAYS)
for filename in os.listdir(LOG_DIR):
if not filename.startswith(LOG_FILE_PREFIX) or not filename.endswith('.log'):
if not (filename.startswith(LOG_FILE_PREFIX) and filename.endswith('.log')):
continue
date_str = filename[len(LOG_FILE_PREFIX):-4]
try:
file_date = datetime.strptime(date_str, '%Y-%m-%d')
if file_date < expire_date:
file_path = os.path.join(LOG_DIR, filename)
file_path = LOG_DIR / filename
os.remove(file_path)
print(f"清理过期日志: {file_path}")
except ValueError:
continue
except Exception as e:
print(f"清理旧日志异常: {str(e)}")
except Exception:
pass
# ===================== 初始化日志器 =====================
def init_module_logger(logger_name):
"""初始化日志器 + 清理旧日志"""
os.makedirs(LOG_DIR, exist_ok=True)
clean_old_logs()
current_date = datetime.now().strftime("%Y-%m-%d")
log_file = os.path.join(LOG_DIR, f"{LOG_FILE_PREFIX}{current_date}.log")
if logger_name in logger_cache:
return logger_cache[logger_name]
clean_old_logs()
logger = logging.getLogger(logger_name)
logger.setLevel(log_level)
logger.setLevel(logging.DEBUG)
if logger.handlers:
logger_cache[logger_name] = logger
return logger
# 文件输出处理器
formatter = logging.Formatter(LOG_FORMAT, datefmt=DATE_FORMAT)
# 文件日志
if FILE_LOG_ENABLE:
current_date = datetime.now().strftime("%Y-%m-%d")
log_file = LOG_DIR / f"{LOG_FILE_PREFIX}{current_date}.log"
file_handler = RotatingFileHandler(
log_file,
maxBytes=10 * 1024 * 1024,
backupCount=10,
encoding='utf-8'
)
formatter = logging.Formatter(LOG_FORMAT, datefmt=DATE_FORMAT)
file_handler.setFormatter(formatter)
file_handler.setLevel(FILE_LOG_LEVEL_INT)
logger.addHandler(file_handler)
# 控制台输出
if console_output:
console_handler = logging.StreamHandler()
# 控制台日志
if CONSOLE_ENABLE:
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(formatter)
console_handler.setLevel(CONSOLE_LOG_LEVEL_INT)
logger.addHandler(console_handler)
logger_cache[logger_name] = logger
return logger
# ===================== 对外日志入口函数 =====================
def algo_log(content, level="INFO", record_once=False):
"""
日志入口函数
自动记录:调用文件名、代码行号、所在函数
"""
# 回溯栈帧,获取真正调用 algo_log 的代码位置
# f_back(1) -> algo_log 自身f_back(2) -> 业务调用处
frame = inspect.currentframe().f_back.f_back
if not frame:
file_name = "unknown"
else:
file_name = os.path.basename(frame.f_code.co_filename)
frame = inspect.currentframe()
if frame:
frame = frame.f_back.f_back
file_name = os.path.basename(frame.f_code.co_filename) if frame else "unknown"
logger = init_module_logger(file_name)
# 单次日志去重
if record_once:
log_key = f"{level.upper()}_{content}"
if log_key in log_once_cache:
return
log_once_cache.add(log_key)
# 日志级别分发
level_upper = level.upper()
log_map = {
log_func_map = {
"DEBUG": logger.debug,
"INFO": logger.info,
"WARNING": logger.warning,
"ERROR": logger.error,
"FATAL": logger.fatal,
"INFO": logger.info
"FATAL": logger.fatal
}
log_func = log_map.get(level_upper, logger.info)
log_func = log_func_map.get(level_upper, logger.info)
log_func(content)

View File

@@ -28,7 +28,6 @@ echo "输出目录:${OUT_DIR}"
python -m nuitka \
--standalone \
--msvc=latest \
--windows-console-mode=disable \
--module-parameter=torch-disable-jit=yes \
--enable-plugin=no-qt \
--include-package=numpy \

View File

@@ -1,7 +1,7 @@
import matplotlib
matplotlib.use('Agg')
import argparse
import sys
# import matplotlib
# matplotlib.use('Agg')
# import argparse
# import sys
import time
from Decoder import Decoder_main
from PubLibrary.RunOnce import is_program_running

View File

@@ -170,6 +170,7 @@ def run_headless():
time.sleep(1) # 等待连接建立
client.send_data('decoderClass', 'mi')
time.sleep(4) # 等待 zmqServer 排空启动积压包datamock 提前连接会积压 ~3s 数据)
# MI_IntervalEpoch = [0.5, 4.5]trial时长 = 4.5-0.5 = 4.0s
_mi_iv = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch')) # [0.5, 4.5]
@@ -222,7 +223,7 @@ def run_headless():
time.sleep(0.5) # ding 提示后等待
client.send_data('train', 0)
time.sleep(train_time + epoch_wait) # 等待刺激时间 + epoch 完成时间
time.sleep(train_time + 0.2) # 等待刺激时间 + epoch 完成时间
trained += 1
client.send_data('rest', 0)
@@ -231,7 +232,7 @@ def run_headless():
# 空闲态样本采集train 1label=2
print(f"\n[Train] 空闲态采集 (train 1) trained={trained}")
client.send_data('train', 1)
time.sleep(train_time + epoch_wait) # 等待刺激时间 + epoch 完成时间
time.sleep(train_time + 0.2) # 等待刺激时间 + epoch 完成时间
trained += 1
client.send_data('rest', 0)