Compare commits
17 Commits
ba4ae92647
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| c27e250fad | |||
| 66c0b71b89 | |||
| 5c7b73b7a4 | |||
| 9690971f43 | |||
| 5a5f103ef6 | |||
| b31bb18dfe | |||
| 38480a2ca3 | |||
| 62e7cab5be | |||
|
|
b26ae2ce3c | ||
| 5488626112 | |||
| d59b0f695f | |||
| 0570d41439 | |||
| 4574798d86 | |||
| d480107b37 | |||
| 2d70fc9956 | |||
|
|
1bbe84eb56 | ||
|
|
f21367bc20 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -2,6 +2,7 @@
|
||||
__pycache__/
|
||||
|
||||
# Distribution / packaging
|
||||
release/
|
||||
build/
|
||||
dist/
|
||||
dist_nuitka/
|
||||
|
||||
16
Decoder.py
16
Decoder.py
@@ -62,6 +62,8 @@ class Decoder_main(threading.Thread):
|
||||
|
||||
# 注册滤波结果回调(示例:打印数据形状)
|
||||
self.sliding_filter.filter_result_callback = self.zmqServer.send_filtered_data
|
||||
# 注册 beta_psd 广播回调,每秒通过 8099 端口发送给上位机
|
||||
self.sliding_filter.set_beta_broadcast_callback(lambda v: self.zmqServer.broadcast_message('beta_psd', v))
|
||||
|
||||
def is_valid_signal(self, data, threshold=1e5): # 判断当前信号是否为有效信号
|
||||
# data: (chans, samples)
|
||||
@@ -155,8 +157,8 @@ class Decoder_main(threading.Thread):
|
||||
# self.blink_b, self.blink_a = signal.butter(4, [self.l_freq / (self.device_info['sample_rate'] / 2), self.h_freq / (self.device_info['sample_rate'] / 2)], btype='band')
|
||||
|
||||
def parameter_init(self,bandPass_low,bandPass_high):
|
||||
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in self.interval_epoch] # epoch截取信息
|
||||
self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])] # 训练样本epoch
|
||||
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in self.interval_epoch] # epoch截取信息 ssmvep [50, 550]
|
||||
self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])] # 训练样本epoch ssmevep [50, 575]
|
||||
self.trainData = [] #训练数据
|
||||
self.trainLabel = [] #训练标签
|
||||
self.plotData = [] #报告分析数据
|
||||
@@ -204,6 +206,9 @@ class Decoder_main(threading.Thread):
|
||||
self.zmqServer.state_mode = 'rest'
|
||||
|
||||
try:
|
||||
if self.zmqServer.open_Impedance:
|
||||
time.sleep(0.005)
|
||||
continue
|
||||
if self.decoder_class == 'ssvep' or self.decoder_class == 'pvs':
|
||||
self.decoder_SSVEP()
|
||||
elif self.decoder_class == 'ssmvep':
|
||||
@@ -213,7 +218,7 @@ class Decoder_main(threading.Thread):
|
||||
else:
|
||||
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
|
||||
time.sleep(0.005)
|
||||
continue;
|
||||
continue
|
||||
self.zmqServer.paradigmBuffer.getData(25)
|
||||
except Exception as e:
|
||||
algo_log(f"Decoder Loop Error: {e}")
|
||||
@@ -231,7 +236,7 @@ class Decoder_main(threading.Thread):
|
||||
if self.zmqServer.open_Impedance: # 阻抗检测状态不解码
|
||||
return
|
||||
data = self.zmqServer.paradigmBuffer.getDataViaSSVEP(50)
|
||||
# algo_log(f"SSVEP取出的:{data.shape}, data = {data[:20]}", level="DEBUG")
|
||||
# algo_log(f"SSVEP取出的:{data.shape}, data = {data[:, :10]}", level="DEBUG")
|
||||
data = data[:self.n_chan, :]
|
||||
if self.decodingSteps == 1 and hasattr(self,'dw'): # 开始预热
|
||||
self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时
|
||||
@@ -284,6 +289,7 @@ class Decoder_main(threading.Thread):
|
||||
and self.trainLabel.count(self.currentLabel) < self.single_train:
|
||||
self.trainData.append(trainTrial)
|
||||
self.trainLabel.append(self.currentLabel)
|
||||
algo_log(f"SSMVEP训练集:{np.shape(self.trainData)}", level="DEBUG")
|
||||
else:
|
||||
time.sleep(0.0001)
|
||||
return
|
||||
@@ -423,7 +429,7 @@ class Decoder_main(threading.Thread):
|
||||
y_pred = torch.max(Cls, 1)[1]
|
||||
self.plotLabel.append(int(y_pred.item()))
|
||||
algo_log(f"MI运动意图识别: {y_pred}")
|
||||
self.zmqServer.broadcast_message('paradigm', int(y_pred.item()))
|
||||
self.zmqServer.broadcast_message('result', int(y_pred.item()))
|
||||
end = time.time()
|
||||
algo_log(f'MI发送给界面完成,耗时{end - start:.3f}s。')
|
||||
else: # 休息状态
|
||||
|
||||
@@ -318,11 +318,7 @@ class ExP():
|
||||
train_pred = torch.max(outputs, 1)[1]
|
||||
train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))
|
||||
|
||||
algo_log('Epoch:', e,
|
||||
' Train loss: %.6f' % loss.detach().cpu().numpy(),
|
||||
' Test loss: %.6f' % loss_test.detach().cpu().numpy(),
|
||||
' Train accuracy %.6f' % train_acc,
|
||||
' Test accuracy is %.6f' % acc, level="debug")
|
||||
algo_log(f"Epoch = {e}, Train loss = {loss.detach().cpu().numpy():.6f}, Test loss = {loss_test.detach().cpu().numpy():.6f}, Train accuracy = {train_acc:.6f}, Test accuracy = {acc:.6f}", level="debug")
|
||||
|
||||
self.log_write.write(str(e) + " " + str(acc) + "\n")
|
||||
num = num + 1
|
||||
@@ -335,8 +331,8 @@ class ExP():
|
||||
|
||||
torch.save(self.model, model_path)
|
||||
averAcc = averAcc / num
|
||||
algo_log('The average accuracy is:', averAcc, level="debug")
|
||||
algo_log('The best accuracy is:', bestAcc, level="debug")
|
||||
algo_log(f"The average accuracy is: {averAcc}", level="debug")
|
||||
algo_log(f"The best accuracy is: {bestAcc}", level="debug")
|
||||
self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n")
|
||||
self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n")
|
||||
|
||||
@@ -366,12 +362,13 @@ def onlineTrain(data_queue,result_queue):
|
||||
data = data_queue.get(timeout=30)
|
||||
all_data, all_label,model_path,n_chan = data['data'], data['label'],data['modelPath'],data['n_chan']
|
||||
exp = ExP(n_chan)
|
||||
algo_log('训练参数: ',np.shape(all_data),np.shape(all_label),model_path, level="debug")
|
||||
algo_log(f"训练参数: {np.shape(all_data)}, {np.shape(all_label)}, {model_path}", level="debug")
|
||||
bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,model_path)
|
||||
algo_log('THE BEST ACCURACY IS ' + str(bestAcc), level="debug")
|
||||
algo_log(f"THE BEST ACCURACY IS {str(bestAcc)}", level="debug")
|
||||
|
||||
endtime = datetime.datetime.now()
|
||||
algo_log('train duration: ',str(endtime - starttime), level="debug")
|
||||
algo_log(f"train duration: {endtime - starttime}", level="debug")
|
||||
|
||||
|
||||
# 将模型或参数传回
|
||||
result_queue.put({
|
||||
@@ -387,7 +384,7 @@ def offlineTrain(all_data,all_label,modelPath):
|
||||
|
||||
# seed_n = np.random.randint(2025)
|
||||
seed_n = 1877
|
||||
algo_log('seed is ' + str(seed_n), level="debug")
|
||||
algo_log(f"seed is {seed_n}", level="debug")
|
||||
random.seed(seed_n)
|
||||
np.random.seed(seed_n)
|
||||
torch.manual_seed(seed_n)
|
||||
@@ -400,7 +397,7 @@ def offlineTrain(all_data,all_label,modelPath):
|
||||
algo_log('THE BEST ACCURACY IS ' + str(bestAcc), level="debug")
|
||||
|
||||
endtime = datetime.datetime.now()
|
||||
algo_log('train duration: ',str(endtime - starttime), level="debug")
|
||||
algo_log(f"train duration: {endtime - starttime}", level="debug")
|
||||
|
||||
|
||||
|
||||
|
||||
19
README.md
19
README.md
@@ -13,6 +13,13 @@ Debug_64ch_Decoder_Optimize is an updated version that fixes several issues and
|
||||
6. decoder class切换问题
|
||||
7. decoder_class切换时,数据重置、各类参数重置
|
||||
|
||||
# realease log
|
||||
- 2026年6月11日11:29:17 打包第一版,包名runDecoder.dist_v0.0.0_beta_20260611.7z
|
||||
- 2026年6月11日12:00:00 打包第二版,包名runDecoder.dist_v0.0.0_beta_20260611.7z
|
||||
- 修复上位机先发decoder_class, 后发open_impedence 带来decoder_main thread 阻塞问题
|
||||
|
||||
- 2026年6月12日15:05:47 runDecoder.dist_v0.0.2_beta_20260612
|
||||
- 优化filter读数精度
|
||||
|
||||
# 常用命令
|
||||
source activate 3in1Py310
|
||||
@@ -22,7 +29,15 @@ python ZeroMQClient_mock.py
|
||||
python filter_test.py
|
||||
python upperHost_stimmock/MI_headless.py
|
||||
|
||||
# 打包命令
|
||||
./nuitka_3in1_package.sh
|
||||
|
||||
|
||||
# 遗留问题
|
||||
# TODO
|
||||
1. mvep是否要把list freq 开放到config
|
||||
2. 滤波器参数 放到config文件
|
||||
|
||||
# debug log
|
||||
## MI
|
||||
Epoch采集完成|收到命令: {'method': 'train'|取出的
|
||||
|
||||
收到命令: {'method': 'train'|收到命令: {'method': 'train'|收到命令: {'method': 'predict'|事件检测到
|
||||
@@ -18,11 +18,11 @@ from logs.log import algo_log
|
||||
class FbccaDw:
|
||||
def __init__(self, fs, num_target, num_chans, num_filter, num_harms, stimTime, parameter, width, winNum,method):
|
||||
algo_log('******************************************', level="debug")
|
||||
algo_log('parameter list', level="debug")
|
||||
algo_log('target:', num_target, level="debug")
|
||||
algo_log('number of filter bank:', num_filter, level="debug")
|
||||
algo_log('parameter:', parameter, level="debug")
|
||||
algo_log('width:', width, level="debug")
|
||||
algo_log('parameter list',level="debug")
|
||||
algo_log(f"target: {num_target}", level="debug")
|
||||
algo_log(f"number of filter bank: {num_filter}", level="debug")
|
||||
algo_log(f"parameter: {parameter}", level="debug")
|
||||
algo_log(f"width: {width}", level="debug")
|
||||
self.phase = 0
|
||||
self.bandWidth = width
|
||||
self.winNum = winNum
|
||||
|
||||
@@ -20,7 +20,7 @@ class Beta_Calculate():
|
||||
alpha_psd = np.sum(self.band_psd(freqs, psd, (8, 13)))
|
||||
theta_psd = np.sum(self.band_psd(freqs, psd, (4, 8)))
|
||||
|
||||
print(f"[功率] β={beta_psd:.2f} | α={alpha_psd:.2f} | θ={theta_psd:.2f}")
|
||||
# print(f"[功率] β={beta_psd:.2f} | α={alpha_psd:.2f} | θ={theta_psd:.2f}")
|
||||
|
||||
return beta_psd, alpha_psd, theta_psd
|
||||
|
||||
|
||||
@@ -5,8 +5,13 @@
|
||||
import numpy as np
|
||||
import time
|
||||
import threading
|
||||
import queue
|
||||
from scipy import signal
|
||||
from logs.log import algo_log
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from Tools.beta_calculate import Beta_Calculate
|
||||
|
||||
class FilterRingBuffer:
|
||||
def __init__(self, n_chan, n_points):
|
||||
@@ -89,7 +94,74 @@ class FilterRingBuffer:
|
||||
self.has_new_data = False # 重置时清空新数据标记
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# 2. 独立滑动滤波类(仅负责滤波业务逻辑,不关心缓存实现)
|
||||
# 2. 独立 Beta PSD 计算线程(避免阻塞滤波主循环的 200ms 定时)
|
||||
# -----------------------------------------------------------------------------
|
||||
class BetaPsdCalculator(threading.Thread):
|
||||
"""独立的 Beta PSD 计算线程,使用队列与滤波主线程解耦"""
|
||||
|
||||
def __init__(self, fs=250, window_size=750):
|
||||
super().__init__(daemon=True)
|
||||
self.fs = fs
|
||||
self.window_size = window_size
|
||||
self._beta_calc = Beta_Calculate(Threshold_value_low=0, Threshold_value_high=0, fs=fs)
|
||||
self._input_queue = queue.Queue(maxsize=2)
|
||||
self._running = threading.Event()
|
||||
self._running.set()
|
||||
self._latest_beta = None
|
||||
self._beta_lock = threading.Lock()
|
||||
self.beta_broadcast_callback = None
|
||||
|
||||
def push_data(self, data):
|
||||
"""供外部调用的线程安全数据推送接口"""
|
||||
try:
|
||||
self._input_queue.put_nowait(data)
|
||||
except queue.Full:
|
||||
try:
|
||||
self._input_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
pass
|
||||
try:
|
||||
self._input_queue.put_nowait(data)
|
||||
except queue.Full:
|
||||
pass
|
||||
|
||||
def get_latest_beta(self):
|
||||
"""获取最新的 beta 值(线程安全)"""
|
||||
with self._beta_lock:
|
||||
return self._latest_beta
|
||||
|
||||
def run(self):
|
||||
while self._running.is_set():
|
||||
try:
|
||||
data = self._input_queue.get(timeout=1.5)
|
||||
if data is None:
|
||||
break
|
||||
try:
|
||||
beta_psd, _, _ = self._beta_calc.calculate_all(
|
||||
data, fs=self.fs, nperseg=min(self.window_size, data.shape[1])
|
||||
)
|
||||
with self._beta_lock:
|
||||
self._latest_beta = round(float(beta_psd), 3)
|
||||
if self.beta_broadcast_callback is not None:
|
||||
self.beta_broadcast_callback(self._latest_beta)
|
||||
except Exception as e:
|
||||
algo_log(f"Beta PSD 计算异常: {e}", level='error')
|
||||
except queue.Empty:
|
||||
pass
|
||||
|
||||
def stop(self):
|
||||
"""停止计算线程"""
|
||||
self._running.clear()
|
||||
try:
|
||||
self._input_queue.put_nowait(None)
|
||||
except queue.Full:
|
||||
pass
|
||||
if self.is_alive():
|
||||
self.join(timeout=2)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# 3. 独立滑动滤波类(仅负责滤波业务逻辑,不关心缓存实现)
|
||||
# -----------------------------------------------------------------------------
|
||||
class SlidingFilter(threading.Thread):
|
||||
def __init__(
|
||||
@@ -118,14 +190,32 @@ class SlidingFilter(threading.Thread):
|
||||
# 滤波结果回调(外部可注册,获取滤波后的数据)
|
||||
self.filter_result_callback = None
|
||||
|
||||
# beta 每秒触发计数(200ms步长,5次 = 1s)
|
||||
self._beta_step_counter = 0
|
||||
self._beta_steps_per_second = max(1, int(round(1.0 / step_sec))) # 5
|
||||
|
||||
self.slide_window = None # 滑动窗口缓存 (n_chan, window_size)
|
||||
self.slide_ready = False # 窗口是否已填满初始数据
|
||||
# 预计算滤波器系数(仅执行一次)
|
||||
self._init_filters()
|
||||
|
||||
# 独立的 Beta 计算线程(避免阻塞滤波主循环)
|
||||
self._beta_thread = BetaPsdCalculator(fs=srate, window_size=self.window_size)
|
||||
|
||||
def start(self):
|
||||
"""同时启动 Beta 计算线程和滤波主线程"""
|
||||
self._beta_thread.start()
|
||||
super().start()
|
||||
|
||||
def set_beta_broadcast_callback(self, callback):
|
||||
"""注册 Beta PSD 广播回调函数"""
|
||||
self._beta_thread.beta_broadcast_callback = callback
|
||||
|
||||
def _init_filters(self):
|
||||
"""预计算所有滤波器系数(仅执行一次)"""
|
||||
# 50Hz工频陷波(Q=30,工业标准)
|
||||
self.b_notch, self.a_notch = signal.iirnotch(50, 30, self.srate)
|
||||
# 8~30Hz带通FIR(65阶,线性相位)
|
||||
# 0.5~45Hz带通FIR(65阶,线性相位)
|
||||
self.b_bp = signal.firwin(
|
||||
numtaps=65,
|
||||
cutoff=[0.5/(self.srate/2), 45/(self.srate/2)],
|
||||
@@ -135,7 +225,7 @@ class SlidingFilter(threading.Thread):
|
||||
self.a_bp = np.array([1.0])
|
||||
|
||||
def _filter_window_data(self, window_data):
|
||||
"""对3秒窗口数据执行滤波,返回无边界效应的200ms数据"""
|
||||
"""对3秒窗口数据执行滤波,返回 (无边界效应的200ms数据, 完整3s滤波数据)"""
|
||||
# 零相位滤波(无延迟,无边界效应)
|
||||
filtered = window_data - np.mean(window_data, axis=-1, keepdims=True)
|
||||
filtered = signal.filtfilt(self.b_notch, self.a_notch, filtered, axis=-1)
|
||||
@@ -146,39 +236,64 @@ class SlidingFilter(threading.Thread):
|
||||
start_idx = self.window_size - 2 * self.step_size
|
||||
end_idx = self.window_size - self.step_size
|
||||
output_data = filtered[:, start_idx:end_idx].copy()
|
||||
return output_data
|
||||
return output_data, filtered
|
||||
|
||||
def run(self):
|
||||
"""线程主逻辑:精确200ms触发一次滤波"""
|
||||
interval = self.step_sec # 200ms = 0.2秒
|
||||
next_run_time = time.perf_counter()
|
||||
while self.running.is_set():
|
||||
# 1. 精确定时等待
|
||||
current_time = time.perf_counter()
|
||||
if current_time < next_run_time:
|
||||
time.sleep(next_run_time - current_time)
|
||||
next_run_time += interval
|
||||
else:
|
||||
algo_log("滤波耗时超过200ms,定时偏移", level='debug')
|
||||
next_run_time = time.perf_counter() + interval
|
||||
interval = self.step_sec # 0.2s
|
||||
# 以启动时刻为绝对时间基准(核心改动)
|
||||
base_time = time.perf_counter()
|
||||
frame_count = 0 # 帧计数器,用于对齐时序
|
||||
|
||||
# ========== 新增核心判断:无新数据则直接跳过 ==========
|
||||
while self.running.is_set():
|
||||
# 计算理论执行时刻:严格按帧序号 × 步长
|
||||
expect_time = base_time + frame_count * interval
|
||||
current_time = time.perf_counter()
|
||||
|
||||
# 精确定时等待
|
||||
if current_time < expect_time:
|
||||
time.sleep(expect_time - current_time)
|
||||
else:
|
||||
# 处理超时:仅告警,不重置基准(防止累积偏移)
|
||||
algo_log(f"滤波任务超时,偏移 {(current_time - expect_time)*1000:.1f} ms", level='debug')
|
||||
|
||||
frame_count += 1 # 帧序号自增,保证周期绝对稳定
|
||||
if not self.ring_buffer.check_and_clear_new_data():
|
||||
# 无新数据,不执行滤波、不发送数据
|
||||
continue
|
||||
|
||||
# 2. 有新数据,才执行原有滤波逻辑
|
||||
# ========== 原有滤波逻辑 ==========
|
||||
try:
|
||||
window_data = self.ring_buffer.get_latest_n_points(self.window_size)
|
||||
if window_data is None:
|
||||
algo_log(f"缓存数据不足,当前缓存{self.ring_buffer.GetDataLenCount()}点,需{self.window_size}点", level='debug')
|
||||
continue
|
||||
if not self.slide_ready:
|
||||
# 阶段1:首次填满3s初始窗口
|
||||
full_data = self.ring_buffer.get_latest_n_points(self.window_size)
|
||||
if full_data is None:
|
||||
algo_log("初始窗口数据不足", level='debug')
|
||||
continue
|
||||
self.slide_window = full_data
|
||||
self.slide_ready = True
|
||||
else:
|
||||
# 阶段2:正常滑动 → 取最新50个新点,增量拼接
|
||||
new_step_data = self.ring_buffer.get_latest_n_points(self.step_size)
|
||||
if new_step_data is None:
|
||||
algo_log("滑动步长数据不足", level='debug')
|
||||
continue
|
||||
# 增量滑动:丢弃前50点,拼接新50点(标准滑动窗口)
|
||||
self.slide_window = np.hstack([
|
||||
self.slide_window[:, self.step_size:],
|
||||
new_step_data
|
||||
])
|
||||
|
||||
filtered_data = self._filter_window_data(window_data)
|
||||
# algo_log(f"滤波后{filtered_data.shape}数据", level='debug')
|
||||
filtered_data, filtered_full = self._filter_window_data(self.slide_window[:64, :])
|
||||
|
||||
# Beta PSD 每秒计算一次
|
||||
self._beta_step_counter += 1
|
||||
if self._beta_step_counter >= self._beta_steps_per_second:
|
||||
self._beta_step_counter = 0
|
||||
self._beta_thread.push_data(filtered_full[:2, :])
|
||||
|
||||
if self.filter_result_callback is not None:
|
||||
self.filter_result_callback(filtered_data[:64, :])
|
||||
self.filter_result_callback(filtered_data)
|
||||
except Exception as e:
|
||||
algo_log(f"滤波执行异常: {e}", level='error')
|
||||
|
||||
@@ -187,17 +302,11 @@ class SlidingFilter(threading.Thread):
|
||||
self.filter_result_callback = callback
|
||||
|
||||
def stop(self):
|
||||
"""停止滤波线程(安全版)"""
|
||||
# 1. 先设置停止标志(Event.clear()是线程安全的)
|
||||
"""停止滤波线程和 Beta 计算线程"""
|
||||
self._beta_thread.stop()
|
||||
self.running.clear()
|
||||
|
||||
# 2. 核心修复:只有线程已启动且正在运行时才调用join
|
||||
if self.is_alive():
|
||||
# 等待线程正常退出,最多1秒
|
||||
self.join(timeout=1)
|
||||
# 超时未退出时打印警告,便于排查问题
|
||||
if self.is_alive():
|
||||
algo_log("警告:滤波线程在1秒内未正常退出,可能存在阻塞操作", level="WARNING")
|
||||
|
||||
# 3. 无论线程是否启动,都打印停止日志
|
||||
algo_log("滤波线程已停止")
|
||||
|
||||
@@ -152,7 +152,8 @@ class zmqServer(threading.Thread):
|
||||
msg = {'method': method, 'params': params}
|
||||
msg_bytes = json.dumps(msg).encode('utf-8')
|
||||
|
||||
algo_log(f"发送命令结果: {msg}", level="DEBUG")
|
||||
if msg['method'] != 'beta_psd':
|
||||
algo_log(f"发送命令结果: {msg}", level="DEBUG")
|
||||
|
||||
# 广播到所有命令客户端
|
||||
for client_id in list(self.cmd_clients):
|
||||
@@ -179,7 +180,7 @@ class zmqServer(threading.Thread):
|
||||
# 转置为上位机需要的[50, 通道数]格式
|
||||
filtered_data = filtered_data.T.astype(np.float64)
|
||||
send_buf = filtered_data.tobytes()
|
||||
algo_log(f"发送滤波数据,长度: {len(send_buf)}字节, filtered_data.shape: {filtered_data.shape}", level="DEBUG", record_once=True)
|
||||
# algo_log(f"发送滤波数据,长度: {len(send_buf)}字节, filtered_data.shape: {filtered_data.shape}", level="DEBUG", record_once=True)
|
||||
self.data_send_queue.put(send_buf)
|
||||
|
||||
def _process_data_send_queue(self):
|
||||
@@ -225,6 +226,9 @@ class zmqServer(threading.Thread):
|
||||
algo_log(f"无效JSON命令: {message_bytes.hex()}", level="ERROR")
|
||||
self.broadcast_message("error", {"code": 400, "message": "无效JSON格式"})
|
||||
return
|
||||
except Exception as e:
|
||||
algo_log(f"_handle_cmd_message exception: {e}", level="ERROR")
|
||||
return
|
||||
|
||||
algo_log(f"收到命令: {message}", level="INFO")
|
||||
method = message.get("method")
|
||||
@@ -270,6 +274,22 @@ class zmqServer(threading.Thread):
|
||||
elif params == 2: #停止解码
|
||||
self.IsExitApp = True
|
||||
self.running = False
|
||||
|
||||
resp = {
|
||||
"method": "predict_response",
|
||||
"params": {
|
||||
"code": 200,
|
||||
"message": "ok"
|
||||
}
|
||||
}
|
||||
try:
|
||||
resp_bytes = json.dumps(resp, ensure_ascii=False).encode("utf-8")
|
||||
self.cmd_socket.send_multipart([ident, b"", resp_bytes])
|
||||
algo_log(f"predict 命令已即时回复客户端 {ident}", level="DEBUG")
|
||||
except Exception as e:
|
||||
algo_log(f"predict 命令回复失败: {e}", level="ERROR")
|
||||
return
|
||||
|
||||
elif method == "rest":
|
||||
self.state_mode = 'rest'
|
||||
elif method == "impedance":
|
||||
@@ -353,24 +373,24 @@ class zmqServer(threading.Thread):
|
||||
def detect_event(self, samples):
|
||||
self.pack_contain_event = False
|
||||
# 第65通道为事件通道
|
||||
event = int(samples[-2][0])
|
||||
# for idx, event in enumerate(events):
|
||||
if event in self.events:
|
||||
new_key = "".join(
|
||||
[
|
||||
str(event),
|
||||
datetime.datetime.now().strftime("%Y-%m-%d \
|
||||
-%H-%M-%S"),
|
||||
]
|
||||
)
|
||||
self.currentLabel = event
|
||||
if event == self.predict_event:
|
||||
self.count_events[new_key] = self.latency + 1
|
||||
else:
|
||||
self.count_events[new_key] = self.train_latency + 1
|
||||
self.event_inner_idx = self.device_info['frame_points'] - 1
|
||||
# algo_log(f"事件检测到: {event},索引: {idx}", level="DEBUG")
|
||||
self.pack_contain_event = True
|
||||
events = np.array(samples[-2], dtype=np.int32).tolist()
|
||||
for idx, event in enumerate(events):
|
||||
if event in self.events:
|
||||
new_key = "".join(
|
||||
[
|
||||
str(event),
|
||||
datetime.datetime.now().strftime("%Y-%m-%d \
|
||||
-%H-%M-%S"),
|
||||
]
|
||||
)
|
||||
self.currentLabel = event
|
||||
if event == self.predict_event:
|
||||
self.count_events[new_key] = self.latency + 1
|
||||
else:
|
||||
self.count_events[new_key] = self.train_latency + 1
|
||||
self.event_inner_idx = idx
|
||||
algo_log(f"事件检测到: {events},索引: {idx}", level="DEBUG")
|
||||
self.pack_contain_event = True
|
||||
|
||||
# 倒计时并清理过期事件
|
||||
drop_items = []
|
||||
@@ -405,13 +425,18 @@ class zmqServer(threading.Thread):
|
||||
frames = self.cmd_socket.recv_multipart()
|
||||
self._handle_cmd_message(frames)
|
||||
|
||||
# 处理8100数据端口消息
|
||||
# 处理8100数据端口消息(排空积压,消除标签延迟)
|
||||
if self.data_socket in socks and socks[self.data_socket] == zmq.POLLIN:
|
||||
frames = self.data_socket.recv_multipart()
|
||||
self._handle_data_message(frames)
|
||||
while True:
|
||||
try:
|
||||
frames = self.data_socket.recv_multipart(zmq.NOBLOCK)
|
||||
self._handle_data_message(frames)
|
||||
except zmq.Again:
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
algo_log(f"服务器主循环异常: {e}", level="ERROR")
|
||||
algo_log(f"服务器主循环异常: {str(e)}", level="ERROR")
|
||||
return
|
||||
finally:
|
||||
self.running = False
|
||||
# 优雅关闭所有资源
|
||||
|
||||
16
config.ini
16
config.ini
@@ -18,11 +18,23 @@ Upper_Port = 8088
|
||||
Decoder_Host = 127.0.0.1
|
||||
Decoder_Port = 8099
|
||||
Serial_port = COM44
|
||||
algo_log_level = DEBUG
|
||||
console_output = 1
|
||||
save_train_data = 0
|
||||
zmqServer_host = 127.0.0.1
|
||||
|
||||
[algo_log]
|
||||
# ========== 文件日志配置 ==========
|
||||
file_log_enable = true
|
||||
file_log_level = DEBUG
|
||||
log_path = exe
|
||||
retention_days = 3
|
||||
|
||||
# ========== 控制台/黑框配置 ==========
|
||||
console_enable = true
|
||||
console_show_window = true
|
||||
console_log_level = DEBUG
|
||||
|
||||
|
||||
|
||||
; 64 导设备配置
|
||||
[device_type_1]
|
||||
sample_rate = 250
|
||||
|
||||
@@ -56,7 +56,7 @@ EMPTY_FRAME = b""
|
||||
|
||||
# 仿真信号配置
|
||||
TARGET_CHANNEL = 0
|
||||
SIGNAL_FREQ_LIST = [3, 13]
|
||||
SIGNAL_FREQ_LIST = [13]
|
||||
SIGNAL_AMP = 1.8
|
||||
NOISE_GAUSSIAN_AMP = 0.4
|
||||
NOISE_POWER50_AMP = 0.3
|
||||
@@ -128,8 +128,8 @@ def generate_eeg_packet(pkt_idx: int) -> np.ndarray:
|
||||
sig = 0.0
|
||||
for freq in SIGNAL_FREQ_LIST:
|
||||
sig += SIGNAL_AMP * np.sin(2 * np.pi * freq * t_arr)
|
||||
sig += NOISE_POWER50_AMP * np.sin(2 * np.pi * 50 * t_arr)
|
||||
sig += NOISE_GAUSSIAN_AMP * np.random.randn(n_point)
|
||||
# sig += NOISE_POWER50_AMP * np.sin(2 * np.pi * 50 * t_arr)
|
||||
# sig += NOISE_GAUSSIAN_AMP * np.random.randn(n_point)
|
||||
data[:, ch] = sig
|
||||
|
||||
# 事件通道、保留通道
|
||||
|
||||
164
logs/log.py
164
logs/log.py
@@ -1,114 +1,172 @@
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
from logging.handlers import RotatingFileHandler
|
||||
import inspect
|
||||
try:
|
||||
import win32gui
|
||||
import win32con
|
||||
WIN32_AVAILABLE = True
|
||||
except ImportError:
|
||||
WIN32_AVAILABLE = False
|
||||
|
||||
from PubLibrary.InifileHelper import IniRead
|
||||
|
||||
# 全局配置
|
||||
console_output = IniRead('system', 'console_output', '1')
|
||||
log_level = IniRead('system', 'algo_log_level', 'INFO')
|
||||
# ===================== 新增:获取 EXE 同级目录 =====================
|
||||
def get_app_root():
|
||||
"""获取 runDecoder.exe 所在的真实根目录(兼容 onefile / standalone)"""
|
||||
if getattr(sys, 'frozen', False):
|
||||
# Nuitka / PyInstaller 打包后走这里
|
||||
app_path = sys.executable
|
||||
else:
|
||||
# 本地源码运行时,取当前脚本目录
|
||||
app_path = os.path.abspath(__file__)
|
||||
return os.path.dirname(app_path)
|
||||
|
||||
# 程序根目录(exe 同级)
|
||||
APP_ROOT = Path(get_app_root())
|
||||
# 日志文件夹名:exe 同级下 logs 目录
|
||||
DEFAULT_LOG_DIR = APP_ROOT / "logs"
|
||||
|
||||
# ===================== 读取 [algo_log] 配置 =====================
|
||||
# 文件日志
|
||||
FILE_LOG_ENABLE = IniRead("algo_log", "file_log_enable", "true").lower() == "true"
|
||||
FILE_LOG_LEVEL = IniRead("algo_log", "file_log_level", "DEBUG").upper()
|
||||
# 优先级:配置文件 > 默认exe同级logs
|
||||
CFG_LOG_PATH = IniRead("algo_log", "log_path", "").strip()
|
||||
if CFG_LOG_PATH == "exe":
|
||||
LOG_DIR = DEFAULT_LOG_DIR
|
||||
else:
|
||||
LOG_DIR = Path(CFG_LOG_PATH)
|
||||
|
||||
LOG_RETENTION_DAYS = int(IniRead("algo_log", "retention_days", 3))
|
||||
|
||||
# 控制台日志 + 黑框控制
|
||||
CONSOLE_ENABLE = IniRead("algo_log", "console_enable", "true").lower() == "true"
|
||||
CONSOLE_SHOW_WINDOW = IniRead("algo_log", "console_show_window", "true").lower() == "true"
|
||||
CONSOLE_LOG_LEVEL = IniRead("algo_log", "console_log_level", "INFO").upper()
|
||||
|
||||
# ===================== 全局常量与缓存 =====================
|
||||
log_once_cache = set()
|
||||
logger_cache = {}
|
||||
LOG_RETENTION_DAYS = 3
|
||||
LOG_DIR = './logs/'
|
||||
LOG_FILE_PREFIX = 'algo_log_'
|
||||
# 确保日志目录存在
|
||||
LOG_DIR.mkdir(parents=True, exist_ok=True)
|
||||
LOG_DIR_STR = str(LOG_DIR) + "\\"
|
||||
|
||||
# 日志格式:时间 - 日志器名 - 级别 - 文件名:行号 - 函数名 - 日志内容
|
||||
# 日志格式
|
||||
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
DATE_FORMAT = '%Y-%m-%d %H:%M:%S'
|
||||
|
||||
# 日志级别映射
|
||||
LEVEL_MAP = {
|
||||
"DEBUG": logging.DEBUG,
|
||||
"INFO": logging.INFO,
|
||||
"WARNING": logging.WARNING,
|
||||
"ERROR": logging.ERROR,
|
||||
"FATAL": logging.FATAL
|
||||
}
|
||||
FILE_LOG_LEVEL_INT = LEVEL_MAP.get(FILE_LOG_LEVEL, logging.INFO)
|
||||
CONSOLE_LOG_LEVEL_INT = LEVEL_MAP.get(CONSOLE_LOG_LEVEL, logging.INFO)
|
||||
|
||||
def clean_old_logs():
|
||||
"""清理超过指定天数的旧日志文件"""
|
||||
# ===================== Windows 控制台黑框显示/隐藏 =====================
|
||||
def control_console_window():
|
||||
if not sys.platform.startswith("win") or not WIN32_AVAILABLE:
|
||||
return
|
||||
try:
|
||||
if not os.path.exists(LOG_DIR):
|
||||
hwnd = win32gui.GetForegroundWindow()
|
||||
if CONSOLE_SHOW_WINDOW:
|
||||
win32gui.ShowWindow(hwnd, win32con.SW_SHOW)
|
||||
else:
|
||||
win32gui.ShowWindow(hwnd, win32con.SW_HIDE)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
control_console_window()
|
||||
|
||||
# ===================== 清理过期日志 =====================
|
||||
def clean_old_logs():
|
||||
try:
|
||||
if not LOG_DIR.exists():
|
||||
return
|
||||
expire_date = datetime.now() - timedelta(days=LOG_RETENTION_DAYS)
|
||||
for filename in os.listdir(LOG_DIR):
|
||||
if not filename.startswith(LOG_FILE_PREFIX) or not filename.endswith('.log'):
|
||||
if not (filename.startswith(LOG_FILE_PREFIX) and filename.endswith('.log')):
|
||||
continue
|
||||
date_str = filename[len(LOG_FILE_PREFIX):-4]
|
||||
try:
|
||||
file_date = datetime.strptime(date_str, '%Y-%m-%d')
|
||||
if file_date < expire_date:
|
||||
file_path = os.path.join(LOG_DIR, filename)
|
||||
file_path = LOG_DIR / filename
|
||||
os.remove(file_path)
|
||||
print(f"清理过期日志: {file_path}")
|
||||
except ValueError:
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"清理旧日志异常: {str(e)}")
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ===================== 初始化日志器 =====================
|
||||
def init_module_logger(logger_name):
|
||||
"""初始化日志器 + 清理旧日志"""
|
||||
os.makedirs(LOG_DIR, exist_ok=True)
|
||||
clean_old_logs()
|
||||
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
log_file = os.path.join(LOG_DIR, f"{LOG_FILE_PREFIX}{current_date}.log")
|
||||
|
||||
if logger_name in logger_cache:
|
||||
return logger_cache[logger_name]
|
||||
|
||||
clean_old_logs()
|
||||
|
||||
logger = logging.getLogger(logger_name)
|
||||
logger.setLevel(log_level)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
if logger.handlers:
|
||||
logger_cache[logger_name] = logger
|
||||
return logger
|
||||
|
||||
# 文件输出处理器
|
||||
file_handler = RotatingFileHandler(
|
||||
log_file,
|
||||
maxBytes=10 * 1024 * 1024,
|
||||
backupCount=10,
|
||||
encoding='utf-8'
|
||||
)
|
||||
formatter = logging.Formatter(LOG_FORMAT, datefmt=DATE_FORMAT)
|
||||
file_handler.setFormatter(formatter)
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
# 控制台输出
|
||||
if console_output:
|
||||
console_handler = logging.StreamHandler()
|
||||
# 文件日志
|
||||
if FILE_LOG_ENABLE:
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
log_file = LOG_DIR / f"{LOG_FILE_PREFIX}{current_date}.log"
|
||||
file_handler = RotatingFileHandler(
|
||||
log_file,
|
||||
maxBytes=10 * 1024 * 1024,
|
||||
backupCount=10,
|
||||
encoding='utf-8'
|
||||
)
|
||||
file_handler.setFormatter(formatter)
|
||||
file_handler.setLevel(FILE_LOG_LEVEL_INT)
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
# 控制台日志
|
||||
if CONSOLE_ENABLE:
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_handler.setFormatter(formatter)
|
||||
console_handler.setLevel(CONSOLE_LOG_LEVEL_INT)
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
logger_cache[logger_name] = logger
|
||||
return logger
|
||||
|
||||
|
||||
# ===================== 对外日志入口函数 =====================
|
||||
def algo_log(content, level="INFO", record_once=False):
|
||||
"""
|
||||
日志入口函数
|
||||
自动记录:调用文件名、代码行号、所在函数
|
||||
"""
|
||||
# 回溯栈帧,获取真正调用 algo_log 的代码位置
|
||||
# f_back(1) -> algo_log 自身,f_back(2) -> 业务调用处
|
||||
frame = inspect.currentframe().f_back.f_back
|
||||
if not frame:
|
||||
file_name = "unknown"
|
||||
else:
|
||||
file_name = os.path.basename(frame.f_code.co_filename)
|
||||
frame = inspect.currentframe()
|
||||
if frame:
|
||||
frame = frame.f_back.f_back
|
||||
file_name = os.path.basename(frame.f_code.co_filename) if frame else "unknown"
|
||||
|
||||
logger = init_module_logger(file_name)
|
||||
|
||||
# 单次日志去重
|
||||
if record_once:
|
||||
log_key = f"{level.upper()}_{content}"
|
||||
if log_key in log_once_cache:
|
||||
return
|
||||
log_once_cache.add(log_key)
|
||||
|
||||
# 日志级别分发
|
||||
level_upper = level.upper()
|
||||
log_map = {
|
||||
log_func_map = {
|
||||
"DEBUG": logger.debug,
|
||||
"INFO": logger.info,
|
||||
"WARNING": logger.warning,
|
||||
"ERROR": logger.error,
|
||||
"FATAL": logger.fatal,
|
||||
"INFO": logger.info
|
||||
"FATAL": logger.fatal
|
||||
}
|
||||
log_func = log_map.get(level_upper, logger.info)
|
||||
log_func = log_func_map.get(level_upper, logger.info)
|
||||
log_func(content)
|
||||
@@ -28,7 +28,6 @@ echo "输出目录:${OUT_DIR}"
|
||||
python -m nuitka \
|
||||
--standalone \
|
||||
--msvc=latest \
|
||||
--windows-console-mode=disable \
|
||||
--module-parameter=torch-disable-jit=yes \
|
||||
--enable-plugin=no-qt \
|
||||
--include-package=numpy \
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
import argparse
|
||||
import sys
|
||||
# import matplotlib
|
||||
# matplotlib.use('Agg')
|
||||
# import argparse
|
||||
# import sys
|
||||
import time
|
||||
from Decoder import Decoder_main
|
||||
from PubLibrary.RunOnce import is_program_running
|
||||
|
||||
@@ -170,6 +170,7 @@ def run_headless():
|
||||
|
||||
time.sleep(1) # 等待连接建立
|
||||
client.send_data('decoderClass', 'mi')
|
||||
time.sleep(4) # 等待 zmqServer 排空启动积压包(datamock 提前连接会积压 ~3s 数据)
|
||||
|
||||
# MI_IntervalEpoch = [0.5, 4.5],trial时长 = 4.5-0.5 = 4.0s
|
||||
_mi_iv = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch')) # [0.5, 4.5]
|
||||
@@ -222,7 +223,7 @@ def run_headless():
|
||||
time.sleep(0.5) # ding 提示后等待
|
||||
|
||||
client.send_data('train', 0)
|
||||
time.sleep(train_time + epoch_wait) # 等待刺激时间 + epoch 完成时间
|
||||
time.sleep(train_time + 0.2) # 等待刺激时间 + epoch 完成时间
|
||||
|
||||
trained += 1
|
||||
client.send_data('rest', 0)
|
||||
@@ -231,7 +232,7 @@ def run_headless():
|
||||
# 空闲态样本采集(train 1,label=2)
|
||||
print(f"\n[Train] 空闲态采集 (train 1) trained={trained}")
|
||||
client.send_data('train', 1)
|
||||
time.sleep(train_time + epoch_wait) # 等待刺激时间 + epoch 完成时间
|
||||
time.sleep(train_time + 0.2) # 等待刺激时间 + epoch 完成时间
|
||||
|
||||
trained += 1
|
||||
client.send_data('rest', 0)
|
||||
|
||||
Reference in New Issue
Block a user