208 lines
7.3 KiB
Python
208 lines
7.3 KiB
Python
# -*-coding:utf-8 -*-
|
||
"""
|
||
数据滤波模块
|
||
"""
|
||
import numpy as np
|
||
import threading
|
||
from logs.log import algo_log
|
||
|
||
class FilterRingBuffer:
|
||
def __init__(self, n_chan, n_points):
|
||
"""
|
||
初始化纯数据环形缓存
|
||
:param n_chan: 通道数
|
||
:param n_points: 总缓存点数(与paradigmRingBuffer参数完全一致)
|
||
"""
|
||
self.n_chan = n_chan
|
||
self.n_points = n_points
|
||
|
||
self.buffer = np.zeros((n_chan, n_points), dtype=np.float64)
|
||
self.current_ptr = 0 # 写入指针
|
||
self.total_samples = 0 # 已写入总点数
|
||
|
||
# 线程安全锁(多线程环境必须)
|
||
self.lock = threading.Lock()
|
||
|
||
def appendBuffer(self, data):
|
||
"""
|
||
追加数据到缓存(与paradigmRingBuffer接口一致)
|
||
:param data: 输入数据,shape=(n_chan, n_samples)
|
||
"""
|
||
with self.lock:
|
||
n = data.shape[1]
|
||
if n == 0:
|
||
return
|
||
|
||
# 环形写入逻辑
|
||
write_end = self.current_ptr + n
|
||
if write_end <= self.n_points:
|
||
self.buffer[:, self.current_ptr:write_end] = data
|
||
else:
|
||
split = self.n_points - self.current_ptr
|
||
self.buffer[:, self.current_ptr:] = data[:, :split]
|
||
self.buffer[:, :write_end - self.n_points] = data[:, split:]
|
||
|
||
# 更新指针和计数
|
||
self.current_ptr = write_end % self.n_points
|
||
self.total_samples = min(self.total_samples + n, self.n_points)
|
||
|
||
def getData(self, count):
|
||
"""
|
||
从读指针位置读取count个点(与paradigmRingBuffer接口一致)
|
||
:param count: 读取点数
|
||
:return: np.ndarray, shape=(n_chan, count)
|
||
"""
|
||
with self.lock:
|
||
count = min(count, self.total_samples)
|
||
if count == 0:
|
||
return np.zeros((self.n_chan, 0))
|
||
|
||
# 环形读取逻辑(与paradigmRingBuffer完全相同)
|
||
end = self.current_ptr
|
||
start = end - count
|
||
if start >= 0:
|
||
return self.buffer[:, start:end].copy()
|
||
else:
|
||
part1 = self.buffer[:, start:]
|
||
part2 = self.buffer[:, :end]
|
||
return np.concatenate((part1, part2), axis=1)
|
||
|
||
def get_latest_n_points(self, n):
|
||
"""
|
||
扩展方法:获取最新的n个点(不移动读指针,用于滑动窗口)
|
||
:param n: 点数
|
||
:return: np.ndarray, shape=(n_chan, n)
|
||
"""
|
||
with self.lock:
|
||
if self.total_samples < n:
|
||
return None
|
||
return self.getData(n)
|
||
|
||
def GetDataLenCount(self):
|
||
"""获取当前缓存总点数(兼容原有接口)"""
|
||
with self.lock:
|
||
return self.total_samples
|
||
|
||
def resetAllPara(self):
|
||
"""重置所有缓存和指针(兼容原有接口)"""
|
||
with self.lock:
|
||
self.buffer.fill(0.0)
|
||
self.current_ptr = 0
|
||
self.total_samples = 0
|
||
|
||
# -----------------------------------------------------------------------------
|
||
# 2. 独立滑动滤波类(仅负责滤波业务逻辑,不关心缓存实现)
|
||
# 可替换任意缓存实现,只要实现appendBuffer、get_latest_n_points接口
|
||
# -----------------------------------------------------------------------------
|
||
class SlidingFilter:
|
||
def __init__(
|
||
self,
|
||
n_chan=66,
|
||
srate=250,
|
||
buffer_sec=5,
|
||
window_sec=3,
|
||
step_sec=0.2,
|
||
packet_size=5
|
||
):
|
||
"""
|
||
初始化滑动滤波器
|
||
:param n_chan: 通道数
|
||
:param srate: 采样率
|
||
:param buffer_sec: 总缓存时长(秒)
|
||
:param window_sec: 滤波窗口时长(秒)
|
||
:param step_sec: 滑动步长/输出时长(秒)
|
||
:param packet_size: 每包数据点数(20ms一包=5点)
|
||
"""
|
||
# 核心参数
|
||
self.n_chan = n_chan
|
||
self.srate = srate
|
||
self.buffer_size = int(srate * buffer_sec)
|
||
self.window_size = int(srate * window_sec)
|
||
self.step_size = int(srate * step_sec)
|
||
self.packet_size = packet_size
|
||
|
||
# 初始化纯数据缓存(解耦核心)
|
||
self.buffer = FilterRingBuffer(n_chan, self.buffer_size)
|
||
|
||
# 滤波触发计数器
|
||
self.packet_count = 0
|
||
self.ready_to_filter = False
|
||
|
||
# 预计算滤波器系数
|
||
self._init_filters()
|
||
|
||
def _init_filters(self):
|
||
"""预计算所有滤波器系数(仅执行一次)"""
|
||
# 50Hz工频陷波(Q=30,工业标准)
|
||
self.b_notch, self.a_notch = signal.iirnotch(50, 30, self.srate)
|
||
# 8~30Hz带通FIR(65阶,线性相位)
|
||
self.b_bp = signal.firwin(
|
||
numtaps=65,
|
||
cutoff=[8/(self.srate/2), 30/(self.srate/2)],
|
||
pass_zero=False,
|
||
window='hamming'
|
||
)
|
||
self.a_bp = np.array([1.0])
|
||
|
||
def append_and_check_trigger(self, raw_data):
|
||
"""
|
||
追加单包原始数据并检查是否触发滤波
|
||
:param raw_data: 上位机原始数据,shape=(packet_size, n_chan)
|
||
:return: bool: 是否触发本次滤波
|
||
"""
|
||
# 转置为标准格式:(通道数, 点数)
|
||
data = raw_data.T.astype(np.float64)
|
||
|
||
# 写入缓存(纯缓存操作)
|
||
self.buffer.appendBuffer(data)
|
||
|
||
# 更新包计数器
|
||
self.packet_count += 1
|
||
|
||
# 检查滤波触发条件:数据≥窗口长度 且 累计满一个步长的包数
|
||
packets_per_step = int(self.step_size / self.packet_size) # 10包=200ms
|
||
if (self.buffer.GetDataLenCount() >= self.window_size
|
||
and self.packet_count >= packets_per_step):
|
||
self.packet_count = 0
|
||
self.ready_to_filter = True
|
||
return True
|
||
return False
|
||
|
||
def filter_and_get_output(self):
|
||
"""
|
||
执行滤波并返回无边界效应的输出数据
|
||
:return: np.ndarray: 滤波后数据,shape=(n_chan, step_size)
|
||
"""
|
||
if not self.ready_to_filter:
|
||
return None
|
||
|
||
# 获取最新的完整滤波窗口数据
|
||
window_data = self.buffer.get_latest_n_points(self.window_size)
|
||
if window_data is None:
|
||
self.ready_to_filter = False
|
||
return None
|
||
|
||
# 零相位滤波(无延迟,无边界效应)
|
||
filtered = window_data - np.mean(window_data, axis=-1, keepdims=True)
|
||
filtered = signal.filtfilt(self.b_notch, self.a_notch, filtered, axis=-1)
|
||
filtered = signal.filtfilt(self.b_bp, self.a_bp, filtered, axis=-1)
|
||
|
||
# 提取倒数第二个步长的数据(完全避开两端边界效应)
|
||
start_idx = self.window_size - 2 * self.step_size
|
||
end_idx = self.window_size - self.step_size
|
||
output_data = filtered[:, start_idx:end_idx].copy()
|
||
|
||
# 重置触发标志
|
||
self.ready_to_filter = False
|
||
|
||
return output_data
|
||
|
||
def reset(self):
|
||
"""重置滤波器和缓存"""
|
||
self.buffer.resetAllPara()
|
||
self.packet_count = 0
|
||
self.ready_to_filter = False
|
||
|
||
def get_buffer_length(self):
|
||
"""获取当前缓存数据长度"""
|
||
return self.buffer.GetDataLenCount() |