From 494515463d518f64a2310b5d817f776509753ec3 Mon Sep 17 00:00:00 2001 From: Ivey Song Date: Sat, 6 Jun 2026 14:57:52 +0800 Subject: [PATCH] =?UTF-8?q?=E4=B8=93=E6=B3=A8=E5=8A=9B=E8=AE=A1=E7=AE=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- concentration/algorithm/calculate_focus.py | 113 +++++++++++++-------- 1 file changed, 71 insertions(+), 42 deletions(-) diff --git a/concentration/algorithm/calculate_focus.py b/concentration/algorithm/calculate_focus.py index b042398..861debf 100644 --- a/concentration/algorithm/calculate_focus.py +++ b/concentration/algorithm/calculate_focus.py @@ -8,6 +8,7 @@ import os # import logging import base64 import io +import math # logger = logging.getLogger(__name__) # @@ -22,7 +23,7 @@ import io class Calculate(): - def __init__(self, Threshold_value_low, Threshold_value_high, fs=250, win_len=10): + def __init__(self, Threshold_value_low, Threshold_value_high, fs=250, win_len=10, config=None): self.Threshold_value_low = Threshold_value_low self.Threshold_value_high = Threshold_value_high self.fs = fs @@ -30,48 +31,74 @@ class Calculate(): self.CLI_result = [] self.EVI_result = [] self.eegQueue = deque(maxlen=win_len) - - # # 存储历史数据用于绘图 - # self.beta_history = [] - # self.alpha_history = [] - # self.theta_history = [] - # self.focus_history = [] - # self.timestamp_history = [] - # - # # 记录开始时间 - # self.start_time = None - # self.recording = False - # - # # 图表保存路径 - # self.chart_dir = "reports" - # if not os.path.exists(self.chart_dir): - # os.makedirs(self.chart_dir) - # print(f"[调试] 创建目录: {self.chart_dir}") - + # 初始化滤波器 self.b_notch, self.a_notch = signal.iirnotch(50 / (self.fs/2), 30) self.b_design = signal.firwin(65, [2 / (self.fs/2), 40 / (self.fs/2)], pass_zero=False) - + + self.last_focus = None + # 异步滤波系数配置(核心手感控制纽) + self.alpha_up = 1 # 上升系数:较小,保证分数平滑爬升,过滤偶发的瞬时高能量 + # alpha_down / shrink_factor 从 config.ini 读取,方便上位机调参 + if config: + self.alpha_down = float(config.get('alpha_down', 0.8)) + self.shrink_factor = float(config.get('shrink_factor', 0.5)) + else: + self.alpha_down = 0.8 + self.shrink_factor = 0.5 print("[调试] Calculate 类初始化完成") - + def calculate_focus(self, beta, alpha, theta): """ - 专注度计算 - 固定映射版本 + 专注度计算 - 三区间门限异步滤波版本 """ + # 0. 频带特征预处理 + theta_mod = theta ** 0.7 + # 原始比值 - raw = beta / (alpha + theta + 1e-10) - - # Sigmoid 映射:让 raw 在 0.3-1.5 区间敏感 - # 参数可调: - # k = 12 (斜率,越大越陡) - # x0 = 0.6 (中心点,raw=0.6时focus≈50) - k = 12.0 - x0 = 0.6 - focus = 100.0 / (1.0 + np.exp(-k * (raw - x0))) - - # 可选:添加滑动平均平滑 + raw = beta / (alpha + theta_mod + 1e-10) + + exponent = 2.0 + + # 1. 防止脑电比值出现负数异常值 + raw_input = max(raw, 0.0) + + # 2. 2次幂纵轴压缩映射 (shrink_factor 从 config.ini 读取) + focus_raw = 100 * self.shrink_factor * (raw_input ** exponent) + + # 3. 计算当前帧的瞬时分数 (基准量级 0-120) + instant_focus = 120 * (1.0 - np.exp(-focus_raw / 100.0)) + + # 4. 核心修改:三区间门限时域滤波 + if self.last_focus is None: + # 冷启动:首帧直接赋值 + focus = instant_focus + else: + # 判断当前瞬时分数是否处于【极端区】(80以上 或 60以下) + if instant_focus > 85.0 or instant_focus < 60.0: + # 执行异步低通时域滤波 + if instant_focus >= self.last_focus: + # 趋势上升:慢爬升 + focus = self.alpha_up * instant_focus + (1 - self.alpha_up) * self.last_focus + else: + # 趋势下降:快跌落 + focus = self.alpha_down * instant_focus + (1 - self.alpha_down) * self.last_focus + else: + # 【高灵敏自由区】(60 <= instant_focus <= 80) + # 不执行异步滤波,分数直接跟随瞬时值,保证中间状态绝对跟手 + focus = instant_focus + + # 5. 更新历史状态缓存 + self.last_focus = focus + + # 打印在线调试日志,方便观察区间切换 + zone_tag = "极端区(滤波)" if (instant_focus > 80 or instant_focus < 60) else "自由区(直通)" + print(f"原始特征比值 raw: {raw:.4f} | 瞬时分数: {instant_focus:.1f} | 滤波后分数: {focus:.1f}") + + # 最终返回整型 return int(focus) + def calculate_all(self, data, fs, nperseg=1000): mean_x = np.mean(data, axis=-1, keepdims=True) data = data - mean_x @@ -90,7 +117,7 @@ class Calculate(): if len(self.focus_result) > 3: self.focus_result.pop(0) final_focus = int(self.simple_moving_average(self.focus_result, window_size=5)) - + cli_denom = alpha_psd + beta_psd CLI_score = np.log(theta_psd / (cli_denom + 1e-10)) if cli_denom > 0 else 0 self.CLI_result.append(CLI_score) @@ -319,14 +346,16 @@ class Calculate(): if eegData.size == 0: return None eegData -= np.mean(eegData, axis=-1, keepdims=True) - eegData = signal.lfilter(self.b_notch, self.a_notch, eegData) - eegData = signal.lfilter(self.b_design, 1, eegData) - focus_score, CLI_score, beta, alpha, theta = self.calculate_all(eegData, fs=self.fs, nperseg=1000) - - # self.add_data_point(focus_score, beta, alpha, theta) - - return focus_score - return None + # eegData = signal.lfilter(self.b_notch, self.a_notch, eegData) # 陷波 + # eegData = signal.lfilter(self.b_design, 1, eegData) # 滤波 + focus_score, CLI_score, beta_psd, alpha_psd, theta_psd = self.calculate_all(eegData, fs=self.fs, nperseg=1000) + + # self.add_data_point(focus_score, beta_psd, alpha_psd, theta_psd) # 已注释(方法已移除) + + # return (focus_score) + return (focus_score, beta_psd) + # return None + class Calculate2():