Files
bci_algo/Zmq/filterProcess.py
2026-06-12 14:30:11 +08:00

313 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*-coding:utf-8 -*-
"""
数据滤波模块
"""
import numpy as np
import time
import threading
import queue
from scipy import signal
from logs.log import algo_log
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from Tools.beta_calculate import Beta_Calculate
class FilterRingBuffer:
def __init__(self, n_chan, n_points):
self.n_chan = n_chan
self.n_points = n_points
self.buffer = np.zeros((n_chan, n_points), dtype=np.float64)
self.current_ptr = 0
self.total_samples = 0
self.lock = threading.Lock() # 仅保护元数据
self.has_new_data = False
def appendBuffer(self, data):
n = data.shape[1]
if n == 0:
return
# 仅加锁读取/更新元数据
with self.lock:
old_ptr = self.current_ptr
new_ptr = (old_ptr + n) % self.n_points
new_total = min(self.total_samples + n, self.n_points)
self.has_new_data = True
# 数组写入(耗时操作,移出锁外)
write_end = old_ptr + n
if write_end <= self.n_points:
self.buffer[:, old_ptr:write_end] = data
else:
split = self.n_points - old_ptr
self.buffer[:, old_ptr:] = data[:, :split]
self.buffer[:, :write_end - self.n_points] = data[:, split:]
# 再次加锁更新最终元数据
with self.lock:
self.current_ptr = new_ptr
self.total_samples = new_total
# ========== 新增:获取&清空新数据标记的方法 ==========
def check_and_clear_new_data(self):
"""检查是否有新数据,并一次性清空标记(消费后重置)"""
with self.lock:
flag = self.has_new_data
if flag:
self.has_new_data = False
return flag
def getData(self, count):
# 加锁获取最新元数据
with self.lock:
count = min(count, self.total_samples)
if count == 0:
return np.zeros((self.n_chan, 0))
end = self.current_ptr
start = end - count
# 数据读取、切片、拼接(无锁)
if start >= 0:
res = self.buffer[:, start:end].copy()
else:
part1 = self.buffer[:, start:]
part2 = self.buffer[:, :end]
res = np.concatenate((part1, part2), axis=1).copy()
return res
def get_latest_n_points(self, n):
with self.lock:
if self.total_samples < n:
return None
return self.getData(n)
def GetDataLenCount(self):
with self.lock:
return self.total_samples
def resetAllPara(self):
with self.lock:
self.buffer.fill(0.0)
self.current_ptr = 0
self.total_samples = 0
self.has_new_data = False # 重置时清空新数据标记
# -----------------------------------------------------------------------------
# 2. 独立 Beta PSD 计算线程(避免阻塞滤波主循环的 200ms 定时)
# -----------------------------------------------------------------------------
class BetaPsdCalculator(threading.Thread):
"""独立的 Beta PSD 计算线程,使用队列与滤波主线程解耦"""
def __init__(self, fs=250, window_size=750):
super().__init__(daemon=True)
self.fs = fs
self.window_size = window_size
self._beta_calc = Beta_Calculate(Threshold_value_low=0, Threshold_value_high=0, fs=fs)
self._input_queue = queue.Queue(maxsize=2)
self._running = threading.Event()
self._running.set()
self._latest_beta = None
self._beta_lock = threading.Lock()
self.beta_broadcast_callback = None
def push_data(self, data):
"""供外部调用的线程安全数据推送接口"""
try:
self._input_queue.put_nowait(data)
except queue.Full:
try:
self._input_queue.get_nowait()
except queue.Empty:
pass
try:
self._input_queue.put_nowait(data)
except queue.Full:
pass
def get_latest_beta(self):
"""获取最新的 beta 值(线程安全)"""
with self._beta_lock:
return self._latest_beta
def run(self):
while self._running.is_set():
try:
data = self._input_queue.get(timeout=1.5)
if data is None:
break
try:
beta_psd, _, _ = self._beta_calc.calculate_all(
data, fs=self.fs, nperseg=min(self.window_size, data.shape[1])
)
with self._beta_lock:
self._latest_beta = round(float(beta_psd), 3)
if self.beta_broadcast_callback is not None:
self.beta_broadcast_callback(self._latest_beta)
except Exception as e:
algo_log(f"Beta PSD 计算异常: {e}", level='error')
except queue.Empty:
pass
def stop(self):
"""停止计算线程"""
self._running.clear()
try:
self._input_queue.put_nowait(None)
except queue.Full:
pass
if self.is_alive():
self.join(timeout=2)
# -----------------------------------------------------------------------------
# 3. 独立滑动滤波类(仅负责滤波业务逻辑,不关心缓存实现)
# -----------------------------------------------------------------------------
class SlidingFilter(threading.Thread):
def __init__(
self,
ring_buffer: FilterRingBuffer,
n_chan=66,
srate=250,
window_sec=3,
step_sec=0.2
):
super().__init__(daemon=True)
# 核心参数
self.n_chan = n_chan
self.srate = srate
self.step_sec = step_sec # 200ms滑动步长
self.window_sec = window_sec # 3秒窗口
self.step_sec = step_sec # 200ms滑动步长
self.window_size = int(srate * window_sec) # 3秒点数250*3=750
self.step_size = int(srate * step_sec) # 200ms点数250*0.2=50
# 关联ZMQServer的环形缓存解耦仅依赖接口
self.ring_buffer = ring_buffer
# 线程控制
self.running = threading.Event()
self.running.set()
# 滤波结果回调(外部可注册,获取滤波后的数据)
self.filter_result_callback = None
# beta 每秒触发计数200ms步长5次 = 1s
self._beta_step_counter = 0
self._beta_steps_per_second = max(1, int(round(1.0 / step_sec))) # 5
self.slide_window = None # 滑动窗口缓存 (n_chan, window_size)
self.slide_ready = False # 窗口是否已填满初始数据
# 预计算滤波器系数(仅执行一次)
self._init_filters()
# 独立的 Beta 计算线程(避免阻塞滤波主循环)
self._beta_thread = BetaPsdCalculator(fs=srate, window_size=self.window_size)
def start(self):
"""同时启动 Beta 计算线程和滤波主线程"""
self._beta_thread.start()
super().start()
def set_beta_broadcast_callback(self, callback):
"""注册 Beta PSD 广播回调函数"""
self._beta_thread.beta_broadcast_callback = callback
def _init_filters(self):
"""预计算所有滤波器系数(仅执行一次)"""
# 50Hz工频陷波Q=30工业标准
self.b_notch, self.a_notch = signal.iirnotch(50, 30, self.srate)
# 0.5~45Hz带通FIR65阶线性相位
self.b_bp = signal.firwin(
numtaps=65,
cutoff=[0.5/(self.srate/2), 45/(self.srate/2)],
pass_zero=False,
window='hamming'
)
self.a_bp = np.array([1.0])
def _filter_window_data(self, window_data):
"""对3秒窗口数据执行滤波返回 (无边界效应的200ms数据, 完整3s滤波数据)"""
# 零相位滤波(无延迟,无边界效应)
filtered = window_data - np.mean(window_data, axis=-1, keepdims=True)
filtered = signal.filtfilt(self.b_notch, self.a_notch, filtered, axis=-1)
filtered = signal.filtfilt(self.b_bp, self.a_bp, filtered, axis=-1)
# 提取倒数第二个200ms的数据完全避开两端边界效应
# 窗口长度750步长50 → start=750-100=650end=750-50=700
start_idx = self.window_size - 2 * self.step_size
end_idx = self.window_size - self.step_size
output_data = filtered[:, start_idx:end_idx].copy()
return output_data, filtered
def run(self):
"""线程主逻辑精确200ms触发一次滤波"""
interval = self.step_sec # 0.2s
# 以启动时刻为绝对时间基准(核心改动)
base_time = time.perf_counter()
frame_count = 0 # 帧计数器,用于对齐时序
while self.running.is_set():
# 计算理论执行时刻:严格按帧序号 × 步长
expect_time = base_time + frame_count * interval
current_time = time.perf_counter()
# 精确定时等待
if current_time < expect_time:
time.sleep(expect_time - current_time)
else:
# 处理超时:仅告警,不重置基准(防止累积偏移)
algo_log(f"滤波任务超时,偏移 {(current_time - expect_time)*1000:.1f} ms", level='debug')
frame_count += 1 # 帧序号自增,保证周期绝对稳定
if not self.ring_buffer.check_and_clear_new_data():
# 无新数据,不执行滤波、不发送数据
continue
# ========== 原有滤波逻辑 ==========
try:
if not self.slide_ready:
# 阶段1首次填满3s初始窗口
full_data = self.ring_buffer.get_latest_n_points(self.window_size)
if full_data is None:
algo_log("初始窗口数据不足", level='debug')
continue
self.slide_window = full_data
self.slide_ready = True
else:
# 阶段2正常滑动 → 取最新50个新点增量拼接
new_step_data = self.ring_buffer.get_latest_n_points(self.step_size)
if new_step_data is None:
algo_log("滑动步长数据不足", level='debug')
continue
# 增量滑动丢弃前50点拼接新50点标准滑动窗口
self.slide_window = np.hstack([
self.slide_window[:, self.step_size:],
new_step_data
])
filtered_data, filtered_full = self._filter_window_data(self.slide_window[:64, :])
# Beta PSD 每秒计算一次
self._beta_step_counter += 1
if self._beta_step_counter >= self._beta_steps_per_second:
self._beta_step_counter = 0
self._beta_thread.push_data(filtered_full[:2, :])
if self.filter_result_callback is not None:
self.filter_result_callback(filtered_data)
except Exception as e:
algo_log(f"滤波执行异常: {e}", level='error')
def set_result_callback(self, callback):
"""注册滤波结果回调函数"""
self.filter_result_callback = callback
def stop(self):
"""停止滤波线程和 Beta 计算线程"""
self._beta_thread.stop()
self.running.clear()
if self.is_alive():
self.join(timeout=1)
if self.is_alive():
algo_log("警告滤波线程在1秒内未正常退出可能存在阻塞操作", level="WARNING")
algo_log("滤波线程已停止")