add filter process
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
数据滤波模块
|
||||
"""
|
||||
import numpy as np
|
||||
import time
|
||||
import threading
|
||||
from scipy import signal
|
||||
from logs.log import algo_log
|
||||
@@ -10,7 +11,7 @@ from logs.log import algo_log
|
||||
class FilterRingBuffer:
|
||||
def __init__(self, n_chan, n_points):
|
||||
"""
|
||||
初始化纯数据环形缓存
|
||||
初始化纯数据环形缓存(线程安全)
|
||||
:param n_chan: 通道数
|
||||
:param n_points: 总缓存点数(与paradigmRingBuffer参数完全一致)
|
||||
"""
|
||||
@@ -18,11 +19,9 @@ class FilterRingBuffer:
|
||||
self.n_points = n_points
|
||||
|
||||
self.buffer = np.zeros((n_chan, n_points), dtype=np.float32)
|
||||
self.current_ptr = 0 # 写入指针
|
||||
self.current_ptr = 0 # 写入指针:指向下一个要写入的位置
|
||||
self.total_samples = 0 # 已写入总点数
|
||||
|
||||
# 线程安全锁(多线程环境必须)
|
||||
self.lock = threading.Lock()
|
||||
self.lock = threading.Lock() # 线程安全锁
|
||||
|
||||
def appendBuffer(self, data):
|
||||
"""
|
||||
@@ -33,8 +32,8 @@ class FilterRingBuffer:
|
||||
n = data.shape[1]
|
||||
if n == 0:
|
||||
return
|
||||
|
||||
# 环形写入逻辑
|
||||
|
||||
# 环形写入逻辑:指针到末尾则绕回
|
||||
write_end = self.current_ptr + n
|
||||
if write_end <= self.n_points:
|
||||
self.buffer[:, self.current_ptr:write_end] = data
|
||||
@@ -42,14 +41,15 @@ class FilterRingBuffer:
|
||||
split = self.n_points - self.current_ptr
|
||||
self.buffer[:, self.current_ptr:] = data[:, :split]
|
||||
self.buffer[:, :write_end - self.n_points] = data[:, split:]
|
||||
|
||||
# 更新指针和计数
|
||||
|
||||
# 更新指针(取模保证环形)和计数(不超过缓存总长度)
|
||||
self.current_ptr = write_end % self.n_points
|
||||
self.total_samples = min(self.total_samples + n, self.n_points)
|
||||
|
||||
def getData(self, count):
|
||||
"""
|
||||
从读指针位置读取count个点(与paradigmRingBuffer接口一致)
|
||||
从最新位置向前读取count个点(环形读取)
|
||||
核心逻辑:current_ptr是下一个写入位置 → 最新数据在current_ptr之前
|
||||
:param count: 读取点数
|
||||
:return: np.ndarray, shape=(n_chan, count)
|
||||
"""
|
||||
@@ -57,14 +57,15 @@ class FilterRingBuffer:
|
||||
count = min(count, self.total_samples)
|
||||
if count == 0:
|
||||
return np.zeros((self.n_chan, 0))
|
||||
|
||||
# 环形读取逻辑(与paradigmRingBuffer完全相同)
|
||||
|
||||
# 环形读取:end是当前写入指针(最新数据的下一位),start是end - count
|
||||
end = self.current_ptr
|
||||
start = end - count
|
||||
if start >= 0:
|
||||
return self.buffer[:, start:end].copy()
|
||||
else:
|
||||
part1 = self.buffer[:, start:]
|
||||
# 跨环形边界:前半部分从缓存末尾取,后半部分从开头取
|
||||
part1 = self.buffer[:, start:] # start为负,等价于n_points + start
|
||||
part2 = self.buffer[:, :end]
|
||||
return np.concatenate((part1, part2), axis=1)
|
||||
|
||||
@@ -72,7 +73,7 @@ class FilterRingBuffer:
|
||||
"""
|
||||
扩展方法:获取最新的n个点(不移动读指针,用于滑动窗口)
|
||||
:param n: 点数
|
||||
:return: np.ndarray, shape=(n_chan, n)
|
||||
:return: np.ndarray, shape=(n_chan, n) | None(数据不足时)
|
||||
"""
|
||||
with self.lock:
|
||||
if self.total_samples < n:
|
||||
@@ -93,43 +94,37 @@ class FilterRingBuffer:
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# 2. 独立滑动滤波类(仅负责滤波业务逻辑,不关心缓存实现)
|
||||
# 可替换任意缓存实现,只要实现appendBuffer、get_latest_n_points接口
|
||||
# -----------------------------------------------------------------------------
|
||||
class SlidingFilter:
|
||||
class SlidingFilter(threading.Thread):
|
||||
def __init__(
|
||||
self,
|
||||
ring_buffer: FilterRingBuffer,
|
||||
n_chan=66,
|
||||
srate=250,
|
||||
buffer_sec=5,
|
||||
window_sec=3,
|
||||
step_sec=0.2,
|
||||
step_sec=0.2, # 200ms滑动步长
|
||||
packet_size=5
|
||||
):
|
||||
"""
|
||||
初始化滑动滤波器
|
||||
:param n_chan: 通道数
|
||||
:param srate: 采样率
|
||||
:param buffer_sec: 总缓存时长(秒)
|
||||
:param window_sec: 滤波窗口时长(秒)
|
||||
:param step_sec: 滑动步长/输出时长(秒)
|
||||
:param packet_size: 每包数据点数(20ms一包=5点)
|
||||
"""
|
||||
super().__init__(daemon=True)
|
||||
# 核心参数
|
||||
self.n_chan = n_chan
|
||||
self.srate = srate
|
||||
self.buffer_size = int(srate * buffer_sec)
|
||||
self.window_size = int(srate * window_sec)
|
||||
self.step_size = int(srate * step_sec)
|
||||
self.step_sec = step_sec # 200ms滑动步长
|
||||
self.window_sec = window_sec # 3秒窗口
|
||||
self.step_sec = step_sec # 200ms滑动步长
|
||||
self.window_size = int(srate * window_sec) # 3秒点数:250*3=750
|
||||
self.step_size = int(srate * step_sec) # 200ms点数:250*0.2=50
|
||||
self.packet_size = packet_size
|
||||
|
||||
# 初始化纯数据缓存(解耦核心)
|
||||
self.buffer = FilterRingBuffer(n_chan, self.buffer_size)
|
||||
|
||||
# 滤波触发计数器
|
||||
self.packet_count = 0
|
||||
self.ready_to_filter = False
|
||||
|
||||
# 预计算滤波器系数
|
||||
|
||||
# 关联ZMQServer的环形缓存(解耦:仅依赖接口)
|
||||
self.ring_buffer = ring_buffer
|
||||
# 线程控制
|
||||
self.running = threading.Event()
|
||||
self.running.set()
|
||||
# 滤波结果回调(外部可注册,获取滤波后的数据)
|
||||
self.filter_result_callback = None
|
||||
|
||||
# 预计算滤波器系数(仅执行一次)
|
||||
self._init_filters()
|
||||
|
||||
def _init_filters(self):
|
||||
@@ -145,65 +140,60 @@ class SlidingFilter:
|
||||
)
|
||||
self.a_bp = np.array([1.0])
|
||||
|
||||
def append_and_check_trigger(self, raw_data):
|
||||
"""
|
||||
追加单包原始数据并检查是否触发滤波
|
||||
:param raw_data: 上位机原始数据,shape=(packet_size, n_chan)
|
||||
:return: bool: 是否触发本次滤波
|
||||
"""
|
||||
# 转置为标准格式:(通道数, 点数)
|
||||
data = raw_data.T.astype(np.float64)
|
||||
|
||||
# 写入缓存(纯缓存操作)
|
||||
self.buffer.appendBuffer(data)
|
||||
|
||||
# 更新包计数器
|
||||
self.packet_count += 1
|
||||
|
||||
# 检查滤波触发条件:数据≥窗口长度 且 累计满一个步长的包数
|
||||
packets_per_step = int(self.step_size / self.packet_size) # 10包=200ms
|
||||
if (self.buffer.GetDataLenCount() >= self.window_size
|
||||
and self.packet_count >= packets_per_step):
|
||||
self.packet_count = 0
|
||||
self.ready_to_filter = True
|
||||
return True
|
||||
return False
|
||||
|
||||
def filter_and_get_output(self):
|
||||
"""
|
||||
执行滤波并返回无边界效应的输出数据
|
||||
:return: np.ndarray: 滤波后数据,shape=(n_chan, step_size)
|
||||
"""
|
||||
if not self.ready_to_filter:
|
||||
return None
|
||||
|
||||
# 获取最新的完整滤波窗口数据
|
||||
window_data = self.buffer.get_latest_n_points(self.window_size)
|
||||
if window_data is None:
|
||||
self.ready_to_filter = False
|
||||
return None
|
||||
|
||||
def _filter_window_data(self, window_data):
|
||||
"""对3秒窗口数据执行滤波,返回无边界效应的200ms数据"""
|
||||
# 零相位滤波(无延迟,无边界效应)
|
||||
filtered = window_data - np.mean(window_data, axis=-1, keepdims=True)
|
||||
filtered = signal.filtfilt(self.b_notch, self.a_notch, filtered, axis=-1)
|
||||
filtered = signal.filtfilt(self.b_bp, self.a_bp, filtered, axis=-1)
|
||||
|
||||
# 提取倒数第二个步长的数据(完全避开两端边界效应)
|
||||
|
||||
# 提取倒数第二个200ms的数据(完全避开两端边界效应)
|
||||
# 窗口长度750,步长50 → start=750-100=650,end=750-50=700
|
||||
start_idx = self.window_size - 2 * self.step_size
|
||||
end_idx = self.window_size - self.step_size
|
||||
output_data = filtered[:, start_idx:end_idx].copy()
|
||||
|
||||
# 重置触发标志
|
||||
self.ready_to_filter = False
|
||||
|
||||
return output_data
|
||||
|
||||
def reset(self):
|
||||
"""重置滤波器和缓存"""
|
||||
self.buffer.resetAllPara()
|
||||
self.packet_count = 0
|
||||
self.ready_to_filter = False
|
||||
def run(self):
|
||||
"""线程主逻辑:精确200ms触发一次滤波"""
|
||||
# 精确定时核心:基于perf_counter计算下一次执行时间,补偿sleep误差
|
||||
interval = self.step_sec # 200ms = 0.2秒
|
||||
next_run_time = time.perf_counter()
|
||||
|
||||
def get_buffer_length(self):
|
||||
"""获取当前缓存数据长度"""
|
||||
return self.buffer.GetDataLenCount()
|
||||
while self.running.is_set():
|
||||
# 1. 等待到下一次执行时间(精确定时)
|
||||
current_time = time.perf_counter()
|
||||
if current_time < next_run_time:
|
||||
time.sleep(next_run_time - current_time)
|
||||
next_run_time += interval # 补偿:下次执行时间基于上一次目标时间
|
||||
else:
|
||||
# 若超时(如滤波耗时超过200ms),重置下一次时间(避免累积误差)
|
||||
algo_log("滤波耗时超过200ms,定时偏移", level='debug')
|
||||
next_run_time = time.perf_counter() + interval
|
||||
|
||||
# 2. 执行滤波逻辑
|
||||
try:
|
||||
# 获取最新的3秒窗口数据
|
||||
window_data = self.ring_buffer.get_latest_n_points(self.window_size)
|
||||
if window_data is None:
|
||||
algo_log(f"缓存数据不足,当前缓存{self.ring_buffer.GetDataLenCount()}点,需{self.window_size}点", level='debug')
|
||||
continue
|
||||
|
||||
# 滤波并提取无边界效应的200ms数据
|
||||
filtered_data = self._filter_window_data(window_data)
|
||||
|
||||
# 回调返回结果(外部可处理)
|
||||
if self.filter_result_callback is not None:
|
||||
self.filter_result_callback(filtered_data[:64, :]) # 只发送前64通道数据
|
||||
|
||||
except Exception as e:
|
||||
algo_log(f"滤波执行异常: {e}", level='error')
|
||||
|
||||
def set_result_callback(self, callback):
|
||||
"""注册滤波结果回调函数"""
|
||||
self.filter_result_callback = callback
|
||||
|
||||
def stop(self):
|
||||
"""停止滤波线程"""
|
||||
self.running.clear()
|
||||
self.join(timeout=1)
|
||||
|
||||
Reference in New Issue
Block a user