Files
bci_algo/Zmq/filterProcess.py
2026-06-12 11:33:48 +08:00

293 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*-coding:utf-8 -*-
"""
数据滤波模块
"""
import numpy as np
import time
import threading
import queue
from scipy import signal
from logs.log import algo_log
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from Tools.beta_calculate import Beta_Calculate
class FilterRingBuffer:
def __init__(self, n_chan, n_points):
self.n_chan = n_chan
self.n_points = n_points
self.buffer = np.zeros((n_chan, n_points), dtype=np.float64)
self.current_ptr = 0
self.total_samples = 0
self.lock = threading.Lock() # 仅保护元数据
self.has_new_data = False
def appendBuffer(self, data):
n = data.shape[1]
if n == 0:
return
# 仅加锁读取/更新元数据
with self.lock:
old_ptr = self.current_ptr
new_ptr = (old_ptr + n) % self.n_points
new_total = min(self.total_samples + n, self.n_points)
self.has_new_data = True
# 数组写入(耗时操作,移出锁外)
write_end = old_ptr + n
if write_end <= self.n_points:
self.buffer[:, old_ptr:write_end] = data
else:
split = self.n_points - old_ptr
self.buffer[:, old_ptr:] = data[:, :split]
self.buffer[:, :write_end - self.n_points] = data[:, split:]
# 再次加锁更新最终元数据
with self.lock:
self.current_ptr = new_ptr
self.total_samples = new_total
# ========== 新增:获取&清空新数据标记的方法 ==========
def check_and_clear_new_data(self):
"""检查是否有新数据,并一次性清空标记(消费后重置)"""
with self.lock:
flag = self.has_new_data
if flag:
self.has_new_data = False
return flag
def getData(self, count):
# 加锁获取最新元数据
with self.lock:
count = min(count, self.total_samples)
if count == 0:
return np.zeros((self.n_chan, 0))
end = self.current_ptr
start = end - count
# 数据读取、切片、拼接(无锁)
if start >= 0:
res = self.buffer[:, start:end].copy()
else:
part1 = self.buffer[:, start:]
part2 = self.buffer[:, :end]
res = np.concatenate((part1, part2), axis=1).copy()
return res
def get_latest_n_points(self, n):
with self.lock:
if self.total_samples < n:
return None
return self.getData(n)
def GetDataLenCount(self):
with self.lock:
return self.total_samples
def resetAllPara(self):
with self.lock:
self.buffer.fill(0.0)
self.current_ptr = 0
self.total_samples = 0
self.has_new_data = False # 重置时清空新数据标记
# -----------------------------------------------------------------------------
# 2. 独立 Beta PSD 计算线程(避免阻塞滤波主循环的 200ms 定时)
# -----------------------------------------------------------------------------
class BetaPsdCalculator(threading.Thread):
"""独立的 Beta PSD 计算线程,使用队列与滤波主线程解耦"""
def __init__(self, fs=250, window_size=750):
super().__init__(daemon=True)
self.fs = fs
self.window_size = window_size
self._beta_calc = Beta_Calculate(Threshold_value_low=0, Threshold_value_high=0, fs=fs)
self._input_queue = queue.Queue(maxsize=2)
self._running = threading.Event()
self._running.set()
self._latest_beta = None
self._beta_lock = threading.Lock()
self.beta_broadcast_callback = None
def push_data(self, data):
"""供外部调用的线程安全数据推送接口"""
try:
self._input_queue.put_nowait(data)
except queue.Full:
try:
self._input_queue.get_nowait()
except queue.Empty:
pass
try:
self._input_queue.put_nowait(data)
except queue.Full:
pass
def get_latest_beta(self):
"""获取最新的 beta 值(线程安全)"""
with self._beta_lock:
return self._latest_beta
def run(self):
while self._running.is_set():
try:
data = self._input_queue.get(timeout=1.5)
if data is None:
break
try:
beta_psd, _, _ = self._beta_calc.calculate_all(
data, fs=self.fs, nperseg=min(self.window_size, data.shape[1])
)
with self._beta_lock:
self._latest_beta = round(float(beta_psd), 3)
if self.beta_broadcast_callback is not None:
self.beta_broadcast_callback(self._latest_beta)
except Exception as e:
algo_log(f"Beta PSD 计算异常: {e}", level='error')
except queue.Empty:
pass
def stop(self):
"""停止计算线程"""
self._running.clear()
try:
self._input_queue.put_nowait(None)
except queue.Full:
pass
if self.is_alive():
self.join(timeout=2)
# -----------------------------------------------------------------------------
# 3. 独立滑动滤波类(仅负责滤波业务逻辑,不关心缓存实现)
# -----------------------------------------------------------------------------
class SlidingFilter(threading.Thread):
def __init__(
self,
ring_buffer: FilterRingBuffer,
n_chan=66,
srate=250,
window_sec=3,
step_sec=0.2
):
super().__init__(daemon=True)
# 核心参数
self.n_chan = n_chan
self.srate = srate
self.step_sec = step_sec # 200ms滑动步长
self.window_sec = window_sec # 3秒窗口
self.step_sec = step_sec # 200ms滑动步长
self.window_size = int(srate * window_sec) # 3秒点数250*3=750
self.step_size = int(srate * step_sec) # 200ms点数250*0.2=50
# 关联ZMQServer的环形缓存解耦仅依赖接口
self.ring_buffer = ring_buffer
# 线程控制
self.running = threading.Event()
self.running.set()
# 滤波结果回调(外部可注册,获取滤波后的数据)
self.filter_result_callback = None
# beta 每秒触发计数200ms步长5次 = 1s
self._beta_step_counter = 0
self._beta_steps_per_second = max(1, int(round(1.0 / step_sec))) # 5
# 预计算滤波器系数(仅执行一次)
self._init_filters()
# 独立的 Beta 计算线程(避免阻塞滤波主循环)
self._beta_thread = BetaPsdCalculator(fs=srate, window_size=self.window_size)
def start(self):
"""同时启动 Beta 计算线程和滤波主线程"""
self._beta_thread.start()
super().start()
def set_beta_broadcast_callback(self, callback):
"""注册 Beta PSD 广播回调函数"""
self._beta_thread.beta_broadcast_callback = callback
def _init_filters(self):
"""预计算所有滤波器系数(仅执行一次)"""
# 50Hz工频陷波Q=30工业标准
self.b_notch, self.a_notch = signal.iirnotch(50, 30, self.srate)
# 0.5~45Hz带通FIR65阶线性相位
self.b_bp = signal.firwin(
numtaps=65,
cutoff=[0.5/(self.srate/2), 45/(self.srate/2)],
pass_zero=False,
window='hamming'
)
self.a_bp = np.array([1.0])
def _filter_window_data(self, window_data):
"""对3秒窗口数据执行滤波返回 (无边界效应的200ms数据, 完整3s滤波数据)"""
# 零相位滤波(无延迟,无边界效应)
filtered = window_data - np.mean(window_data, axis=-1, keepdims=True)
filtered = signal.filtfilt(self.b_notch, self.a_notch, filtered, axis=-1)
filtered = signal.filtfilt(self.b_bp, self.a_bp, filtered, axis=-1)
# 提取倒数第二个200ms的数据完全避开两端边界效应
# 窗口长度750步长50 → start=750-100=650end=750-50=700
start_idx = self.window_size - 2 * self.step_size
end_idx = self.window_size - self.step_size
output_data = filtered[:, start_idx:end_idx].copy()
return output_data, filtered
def run(self):
"""线程主逻辑精确200ms触发一次滤波"""
interval = self.step_sec # 200ms = 0.2秒
next_run_time = time.perf_counter()
while self.running.is_set():
# 1. 精确定时等待
current_time = time.perf_counter()
if current_time < next_run_time:
time.sleep(next_run_time - current_time)
next_run_time += interval
else:
algo_log("滤波耗时超过200ms定时偏移", level='debug')
next_run_time = time.perf_counter() + interval
# ========== 新增核心判断:无新数据则直接跳过 ==========
if not self.ring_buffer.check_and_clear_new_data():
# 无新数据,不执行滤波、不发送数据
continue
# 2. 有新数据,才执行原有滤波逻辑
try:
window_data = self.ring_buffer.get_latest_n_points(self.window_size)
if window_data is None:
algo_log(f"缓存数据不足,当前缓存{self.ring_buffer.GetDataLenCount()}点,需{self.window_size}", level='debug')
continue
filtered_data, filtered_full = self._filter_window_data(window_data)
# algo_log(f"滤波后{filtered_data.shape}数据", level='debug')
# ========== beta_psd 每秒计算一次Fp1/Fp2通道索引 0/1==========
self._beta_step_counter += 1
if self._beta_step_counter >= self._beta_steps_per_second:
self._beta_step_counter = 0
# 仅推送数据到队列,不阻塞等待计算完成
self._beta_thread.push_data(filtered_full[:2, :].copy())
if self.filter_result_callback is not None:
self.filter_result_callback(filtered_data[:64, :])
except Exception as e:
algo_log(f"滤波执行异常: {e}", level='error')
def set_result_callback(self, callback):
"""注册滤波结果回调函数"""
self.filter_result_callback = callback
def stop(self):
"""停止滤波线程和 Beta 计算线程"""
self._beta_thread.stop()
self.running.clear()
if self.is_alive():
self.join(timeout=1)
if self.is_alive():
algo_log("警告滤波线程在1秒内未正常退出可能存在阻塞操作", level="WARNING")
algo_log("滤波线程已停止")