Files
bci_algo/concentration/algorithm/calculate_focus.py
2026-06-05 09:34:29 +08:00

396 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import numpy as np
from scipy.signal import welch
from scipy.fft import fft
from scipy import signal
from collections import deque
import time
import os
# import logging
import base64
import io
# logger = logging.getLogger(__name__)
#
# try:
# import matplotlib
# matplotlib.use('Agg')
# import matplotlib.pyplot as plt
# MATPLOTLIB_AVAILABLE = True
# except ImportError:
# MATPLOTLIB_AVAILABLE = False
# logger.warning("matplotlib未安装报告图表功能不可用")
class Calculate():
    """Streaming EEG focus estimator based on Welch band powers.

    EEG chunks are buffered in a sliding window of ``win_len`` chunks (see
    ``queueOpt``). Once the window is full, the concatenated signal is
    notch- and band-pass-filtered. Its theta (4-8 Hz), alpha (8-13 Hz) and
    beta (13-30 Hz) powers are then estimated, and a smoothed 0-100 focus
    score is returned.
    """

    # NOTE: the commented-out chart/report helpers (start_recording,
    # stop_recording, add_data_point, save_chart_to_file,
    # generate_chart_base64) were dead code and have been removed; recover
    # them from version control if the reporting feature is revived.

    def __init__(self, Threshold_value_low, Threshold_value_high, fs=250, win_len=10):
        """Create the estimator and design its filters.

        Args:
            Threshold_value_low: Stored on the instance; not used by this
                class's scoring.
            Threshold_value_high: Stored on the instance; not used by this
                class's scoring.
            fs: Sampling rate in Hz.
            win_len: Number of chunks held in the sliding window.
        """
        self.Threshold_value_low = Threshold_value_low
        self.Threshold_value_high = Threshold_value_high
        self.fs = fs
        self.focus_result = []  # recent per-window focus scores (last 3 kept)
        self.CLI_result = []    # recent per-window CLI values (last 5 kept)
        self.EVI_result = []    # not used in this class; kept for compatibility
        self.eegQueue = deque(maxlen=win_len)
        # Filters: a 50 Hz IIR notch (Q=30, presumably for mains interference)
        # and a 65-tap FIR band-pass from 2 to 40 Hz.
        self.b_notch, self.a_notch = signal.iirnotch(50 / (self.fs/2), 30)
        self.b_design = signal.firwin(65, [2 / (self.fs/2), 40 / (self.fs/2)], pass_zero=False)
        print("[调试] Calculate 类初始化完成")

    def calculate_focus(self, beta, alpha, theta, k=12.0, x0=0.6):
        """Map band powers to a 0-100 focus score with a fixed sigmoid.

        The ratio ``beta / (alpha + theta)`` is passed through
        ``100 / (1 + exp(-k * (ratio - x0)))``. A ratio equal to ``x0`` maps
        to 50. The defaults were chosen so that the score responds to ratios
        in roughly the 0.3-1.5 range.

        Args:
            beta: Beta-band power.
            alpha: Alpha-band power.
            theta: Theta-band power.
            k: Sigmoid slope (larger = steeper). Default 12.0.
            x0: Ratio that maps to a score of 50. Default 0.6.

        Returns:
            int: The score converted with ``int()`` (truncated).
        """
        raw = beta / (alpha + theta + 1e-10)
        focus = 100.0 / (1.0 + np.exp(-k * (raw - x0)))
        return int(focus)

    def calculate_all(self, data, fs, nperseg=1000):
        """Compute smoothed focus and CLI scores for one window of EEG.

        Args:
            data: 2-D array. Assumed to be (channels, samples); band powers
                are summed over channels.
            fs: Sampling rate in Hz.
            nperseg: Welch segment length (clipped to the data length).

        Returns:
            tuple: ``(final_focus, final_CLI, beta_psd, alpha_psd, theta_psd)``.
            ``final_focus`` is an int in [0, 100] averaged over the last 3
            windows. ``final_CLI`` is ``log(theta / (alpha + beta))`` averaged
            over the last 5 windows and rounded to 2 decimals. The remaining
            values are this window's band powers summed over channels.
        """
        data = data - np.mean(data, axis=-1, keepdims=True)
        freqs, psd = self.compute_psd_multichannel(data, fs, nperseg)
        # NOTE(review): band_psd is inclusive at both edges, so a bin lying
        # exactly on 8 Hz or 13 Hz is counted in two adjacent bands.
        beta_psd = np.sum(self.band_psd(freqs, psd, (13, 30)))
        alpha_psd = np.sum(self.band_psd(freqs, psd, (8, 13)))
        theta_psd = np.sum(self.band_psd(freqs, psd, (4, 8)))
        print(f"[功率] β={beta_psd:.2f} | α={alpha_psd:.2f} | θ={theta_psd:.2f}")

        focus_score = self.calculate_focus(beta_psd, alpha_psd, theta_psd)
        focus_score = max(0, min(100, focus_score))
        self.focus_result.append(focus_score)
        if len(self.focus_result) > 3:
            self.focus_result.pop(0)
        # At most 3 scores are retained, so this is a 3-point moving average.
        final_focus = int(self.simple_moving_average(self.focus_result, window_size=3))

        CLI_score = self._compute_cli(theta_psd, alpha_psd, beta_psd)
        self.CLI_result.append(CLI_score)
        if len(self.CLI_result) > 5:
            self.CLI_result.pop(0)
        final_CLI = round(self.simple_moving_average(self.CLI_result, window_size=5), 2)
        return final_focus, final_CLI, beta_psd, alpha_psd, theta_psd

    @staticmethod
    def _compute_cli(theta_psd, alpha_psd, beta_psd):
        """Return ``log(theta / (alpha + beta))``, or 0 for degenerate powers.

        Both the numerator and the denominator must be positive. Otherwise
        ``log`` would yield -inf or NaN. That value would stay in the CLI
        history and poison the moving average for the next 5 windows.
        """
        cli_denom = alpha_psd + beta_psd
        if cli_denom > 0 and theta_psd > 0:
            return np.log(theta_psd / (cli_denom + 1e-10))
        return 0

    def compute_psd_multichannel(self, data, fs=250, nperseg=1000, noverlap=500):
        """Welch PSD along the last axis of ``data``.

        ``nperseg`` is clipped to the number of samples. ``noverlap`` falls
        back to half a segment whenever it is not smaller than the segment
        length, because ``welch`` requires ``noverlap < nperseg``.

        Returns:
            tuple: ``(freqs, psd)``. For input with zero samples this is an
            empty frequency array and a ``(data.shape[0], 0)`` zero array.
        """
        nperseg = min(nperseg, data.shape[-1])
        if nperseg == 0:
            return np.array([]), np.zeros((data.shape[0], 0))
        if noverlap >= nperseg:
            noverlap = int(nperseg / 2)
        freqs, psd = welch(data, fs=fs, nperseg=nperseg, noverlap=noverlap, axis=-1)
        return freqs, psd

    def band_psd(self, freqs, psd, band):
        """Sum, per row of ``psd``, the bins with ``band[0] <= f <= band[1]``."""
        idx = np.logical_and(freqs >= band[0], freqs <= band[1])
        return np.sum(psd[:, idx], axis=-1)

    def simple_moving_average(self, data, window_size=5):
        """Mean of the last ``window_size`` items of ``data``.

        Returns 30 for empty input (a fixed placeholder value).
        """
        if len(data) == 0:
            return 30
        window = data[-window_size:]
        return sum(window) / len(window)

    def reset_queue(self):
        """Drop all buffered EEG chunks (score histories are left untouched)."""
        self.eegQueue.clear()

    def queueOpt(self, data):
        """Push one EEG chunk into the sliding window and score it when full.

        Args:
            data: numpy array chunk. Chunks are joined with ``np.hstack``,
                which for 2-D chunks concatenates along the last axis
                (presumably (channels, samples)).

        Returns:
            The int focus score once the window holds ``win_len`` chunks.
            Otherwise, and for None/empty input, returns None.
        """
        if data is None or data.size == 0:
            return None
        # A bounded deque discards the oldest chunk automatically when full.
        self.eegQueue.append(data)
        if len(self.eegQueue) < self.eegQueue.maxlen:
            return None
        eegData = np.hstack(list(self.eegQueue))
        # Out-of-place subtraction: an in-place ``-=`` of a float mean fails
        # with a casting error when the chunks have an integer dtype.
        eegData = eegData - np.mean(eegData, axis=-1, keepdims=True)
        eegData = signal.lfilter(self.b_notch, self.a_notch, eegData)
        eegData = signal.lfilter(self.b_design, 1, eegData)
        return self.calculate_all(eegData, fs=self.fs, nperseg=1000)[0]
class Calculate2():
    """FFT-based focus / distraction / flow estimator for single EEG windows.

    Each call to ``calculate_all`` processes one window. Scores are rescaled
    with the constructor thresholds and smoothed over recent calls.
    """

    def __init__(self, Threshold_value_low, Threshold_value_high):
        """Store the rescaling thresholds and empty score histories.

        Args:
            Threshold_value_low: Raw beta/(alpha+theta) ratio that maps to a
                focus score of 0.
            Threshold_value_high: Raw ratio that maps to a focus score of
                100; also the divisor used for the flow score.
        """
        self.Threshold_value_low = Threshold_value_low
        self.Threshold_value_high = Threshold_value_high
        self.focus_result = []  # last 3 rescaled focus scores
        self.theta_result = []  # last 30 theta powers (distraction baseline)
        self.alpha_result = []  # last 30 alpha powers (distraction baseline)
        self.flow_result = []   # last 3 flow scores

    def calculate_all(self, data, fs, L=2500):
        """Compute focus, distraction and flow scores from one EEG window.

        Args:
            data: 2-D array. Assumed to be (channels, samples); the FFT runs
                along the last axis.
            fs: Sampling rate in Hz.
            L: FFT length. The input is truncated or zero-padded to ``L``
                samples so that bin k really corresponds to ``k * fs / L``
                Hz, as ``PSD`` assumes. The default 2500 is 10 s at 250 Hz.

        Returns:
            tuple: ``(final_focus, distraction_score, final_flow)``.

            - ``final_focus``: beta/(alpha+theta), linearly rescaled so that
              the low/high thresholds map to 0/100. It is averaged over the
              last 3 calls and converted with ``int()`` (not clipped).
            - ``distraction_score``: ``(theta / mean recent theta) *
              (1 - alpha / mean recent alpha)``, with baselines taken over
              the last 30 calls.
            - ``final_flow``: gamma/beta as a percentage of
              ``Threshold_value_high``, averaged over the last 3 calls and
              converted with ``int()``.
        """
        eps = 1e-10  # avoids 0/0 -> NaN (and int(NaN) errors) on flat input
        data = data - np.mean(data, axis=-1, keepdims=True)
        # Single-sided amplitude spectrum of an L-point FFT.
        Y = fft(data, n=L, axis=-1)
        P2 = np.abs(Y / L)
        P1 = P2[:, :L // 2 + 1]
        P1[:, 1:-1] = 2 * P1[:, 1:-1]
        beta_power = self.PSD(P1, L, fs, 13, 30)
        alpha_power = self.PSD(P1, L, fs, 8, 13)
        theta_power = self.PSD(P1, L, fs, 4, 8)
        gamma_power = self.PSD(P1, L, fs, 30, 100)

        focus_score = beta_power / (alpha_power + theta_power + eps)
        print('focus score:', focus_score)
        focus_score = ((focus_score - self.Threshold_value_low) * 100) / (self.Threshold_value_high - self.Threshold_value_low)
        self.focus_result.append(focus_score)
        if len(self.focus_result) > 3:
            self.focus_result.pop(0)
        final_focus = int(self.simple_moving_average(self.focus_result, window_size=3))

        self.theta_result.append(theta_power)
        if len(self.theta_result) > 30:
            self.theta_result.pop(0)
        self.alpha_result.append(alpha_power)
        if len(self.alpha_result) > 30:
            self.alpha_result.pop(0)
        rest_theta = self.simple_moving_average(self.theta_result, window_size=30)
        rest_alpha = self.simple_moving_average(self.alpha_result, window_size=30)
        distraction_score = (theta_power / (rest_theta + eps)) * (1 - (alpha_power / (rest_alpha + eps)))

        flow_score = gamma_power / (beta_power + eps)
        flow_score = (flow_score / self.Threshold_value_high) * 100
        self.flow_result.append(flow_score)
        if len(self.flow_result) > 3:
            self.flow_result.pop(0)
        final_flow = int(self.simple_moving_average(self.flow_result, window_size=3))
        return final_focus, distraction_score, final_flow

    def PSD(self, P1, L, Fs, s_freq, e_freq):
        """Sum of squared amplitudes in bins [s_freq, e_freq) over all rows.

        Bin k of ``P1`` corresponds to ``k * Fs / L`` Hz. The end index is
        clipped to the available bins so that a band reaching past the last
        bin (e.g. gamma at a low Fs) is truncated instead of raising
        IndexError.
        """
        s_point = round(s_freq * L / Fs)
        e_point = min(round(e_freq * L / Fs), P1.shape[1])
        return np.sum(P1[:, s_point:e_point] ** 2)

    def simple_moving_average(self, data, window_size=3):
        """Mean of the last ``window_size`` items of ``data``.

        NOTE(review): returns an empty list (not a number) for empty input;
        callers here always append before averaging, so it is not hit.
        """
        if len(data) == 0:
            return []
        window = data[-window_size:]
        return sum(window) / len(window)