Files
bci_algo/Zmq/filterProcess.py
2026-06-06 09:16:49 +08:00

208 lines
7.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*-coding:utf-8 -*-
"""
数据滤波模块
"""
import numpy as np
import threading
from logs.log import algo_log
class FilterRingBuffer:
def __init__(self, n_chan, n_points):
"""
初始化纯数据环形缓存
:param n_chan: 通道数
:param n_points: 总缓存点数与paradigmRingBuffer参数完全一致
"""
self.n_chan = n_chan
self.n_points = n_points
self.buffer = np.zeros((n_chan, n_points), dtype=np.float64)
self.current_ptr = 0 # 写入指针
self.total_samples = 0 # 已写入总点数
# 线程安全锁(多线程环境必须)
self.lock = threading.Lock()
def appendBuffer(self, data):
"""
追加数据到缓存与paradigmRingBuffer接口一致
:param data: 输入数据shape=(n_chan, n_samples)
"""
with self.lock:
n = data.shape[1]
if n == 0:
return
# 环形写入逻辑
write_end = self.current_ptr + n
if write_end <= self.n_points:
self.buffer[:, self.current_ptr:write_end] = data
else:
split = self.n_points - self.current_ptr
self.buffer[:, self.current_ptr:] = data[:, :split]
self.buffer[:, :write_end - self.n_points] = data[:, split:]
# 更新指针和计数
self.current_ptr = write_end % self.n_points
self.total_samples = min(self.total_samples + n, self.n_points)
def getData(self, count):
"""
从读指针位置读取count个点与paradigmRingBuffer接口一致
:param count: 读取点数
:return: np.ndarray, shape=(n_chan, count)
"""
with self.lock:
count = min(count, self.total_samples)
if count == 0:
return np.zeros((self.n_chan, 0))
# 环形读取逻辑与paradigmRingBuffer完全相同
end = self.current_ptr
start = end - count
if start >= 0:
return self.buffer[:, start:end].copy()
else:
part1 = self.buffer[:, start:]
part2 = self.buffer[:, :end]
return np.concatenate((part1, part2), axis=1)
def get_latest_n_points(self, n):
"""
扩展方法获取最新的n个点不移动读指针用于滑动窗口
:param n: 点数
:return: np.ndarray, shape=(n_chan, n)
"""
with self.lock:
if self.total_samples < n:
return None
return self.getData(n)
def GetDataLenCount(self):
"""获取当前缓存总点数(兼容原有接口)"""
with self.lock:
return self.total_samples
def resetAllPara(self):
"""重置所有缓存和指针(兼容原有接口)"""
with self.lock:
self.buffer.fill(0.0)
self.current_ptr = 0
self.total_samples = 0
# -----------------------------------------------------------------------------
# 2. 独立滑动滤波类(仅负责滤波业务逻辑,不关心缓存实现)
# 可替换任意缓存实现只要实现appendBuffer、get_latest_n_points接口
# -----------------------------------------------------------------------------
class SlidingFilter:
def __init__(
self,
n_chan=66,
srate=250,
buffer_sec=5,
window_sec=3,
step_sec=0.2,
packet_size=5
):
"""
初始化滑动滤波器
:param n_chan: 通道数
:param srate: 采样率
:param buffer_sec: 总缓存时长(秒)
:param window_sec: 滤波窗口时长(秒)
:param step_sec: 滑动步长/输出时长(秒)
:param packet_size: 每包数据点数20ms一包=5点
"""
# 核心参数
self.n_chan = n_chan
self.srate = srate
self.buffer_size = int(srate * buffer_sec)
self.window_size = int(srate * window_sec)
self.step_size = int(srate * step_sec)
self.packet_size = packet_size
# 初始化纯数据缓存(解耦核心)
self.buffer = FilterRingBuffer(n_chan, self.buffer_size)
# 滤波触发计数器
self.packet_count = 0
self.ready_to_filter = False
# 预计算滤波器系数
self._init_filters()
def _init_filters(self):
"""预计算所有滤波器系数(仅执行一次)"""
# 50Hz工频陷波Q=30工业标准
self.b_notch, self.a_notch = signal.iirnotch(50, 30, self.srate)
# 8~30Hz带通FIR65阶线性相位
self.b_bp = signal.firwin(
numtaps=65,
cutoff=[8/(self.srate/2), 30/(self.srate/2)],
pass_zero=False,
window='hamming'
)
self.a_bp = np.array([1.0])
def append_and_check_trigger(self, raw_data):
"""
追加单包原始数据并检查是否触发滤波
:param raw_data: 上位机原始数据shape=(packet_size, n_chan)
:return: bool: 是否触发本次滤波
"""
# 转置为标准格式:(通道数, 点数)
data = raw_data.T.astype(np.float64)
# 写入缓存(纯缓存操作)
self.buffer.appendBuffer(data)
# 更新包计数器
self.packet_count += 1
# 检查滤波触发条件:数据≥窗口长度 且 累计满一个步长的包数
packets_per_step = int(self.step_size / self.packet_size) # 10包=200ms
if (self.buffer.GetDataLenCount() >= self.window_size
and self.packet_count >= packets_per_step):
self.packet_count = 0
self.ready_to_filter = True
return True
return False
def filter_and_get_output(self):
"""
执行滤波并返回无边界效应的输出数据
:return: np.ndarray: 滤波后数据shape=(n_chan, step_size)
"""
if not self.ready_to_filter:
return None
# 获取最新的完整滤波窗口数据
window_data = self.buffer.get_latest_n_points(self.window_size)
if window_data is None:
self.ready_to_filter = False
return None
# 零相位滤波(无延迟,无边界效应)
filtered = window_data - np.mean(window_data, axis=-1, keepdims=True)
filtered = signal.filtfilt(self.b_notch, self.a_notch, filtered, axis=-1)
filtered = signal.filtfilt(self.b_bp, self.a_bp, filtered, axis=-1)
# 提取倒数第二个步长的数据(完全避开两端边界效应)
start_idx = self.window_size - 2 * self.step_size
end_idx = self.window_size - self.step_size
output_data = filtered[:, start_idx:end_idx].copy()
# 重置触发标志
self.ready_to_filter = False
return output_data
def reset(self):
"""重置滤波器和缓存"""
self.buffer.resetAllPara()
self.packet_count = 0
self.ready_to_filter = False
def get_buffer_length(self):
"""获取当前缓存数据长度"""
return self.buffer.GetDataLenCount()