fitler buffer with lock
This commit is contained in:
@@ -17,34 +17,39 @@ class FilterRingBuffer:
|
|||||||
"""
|
"""
|
||||||
self.n_chan = n_chan
|
self.n_chan = n_chan
|
||||||
self.n_points = n_points
|
self.n_points = n_points
|
||||||
|
|
||||||
self.buffer = np.zeros((n_chan, n_points), dtype=np.float64)
|
self.buffer = np.zeros((n_chan, n_points), dtype=np.float64)
|
||||||
self.current_ptr = 0 # 写入指针:指向下一个要写入的位置
|
self.current_ptr = 0
|
||||||
self.total_samples = 0 # 已写入总点数
|
self.total_samples = 0
|
||||||
self.lock = threading.Lock() # 线程安全锁
|
self.lock = threading.Lock() # 仅保护元数据
|
||||||
|
|
||||||
def appendBuffer(self, data):
|
def appendBuffer(self, data):
|
||||||
"""
|
"""
|
||||||
追加数据到缓存(与paradigmRingBuffer接口一致)
|
追加数据到缓存(与paradigmRingBuffer接口一致)
|
||||||
:param data: 输入数据,shape=(n_chan, n_samples)
|
:param data: 输入数据,shape=(n_chan, n_samples)
|
||||||
"""
|
"""
|
||||||
with self.lock:
|
|
||||||
n = data.shape[1]
|
n = data.shape[1]
|
||||||
if n == 0:
|
if n == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
# 环形写入逻辑:指针到末尾则绕回
|
# -------- 第一步:仅加锁读取/更新元数据(持锁极短)--------
|
||||||
write_end = self.current_ptr + n
|
with self.lock:
|
||||||
|
old_ptr = self.current_ptr
|
||||||
|
new_ptr = (old_ptr + n) % self.n_points
|
||||||
|
new_total = min(self.total_samples + n, self.n_points)
|
||||||
|
|
||||||
|
# -------- 第二步:数组写入(耗时操作,移出锁外)--------
|
||||||
|
write_end = old_ptr + n
|
||||||
if write_end <= self.n_points:
|
if write_end <= self.n_points:
|
||||||
self.buffer[:, self.current_ptr:write_end] = data
|
self.buffer[:, old_ptr:write_end] = data
|
||||||
else:
|
else:
|
||||||
split = self.n_points - self.current_ptr
|
split = self.n_points - old_ptr
|
||||||
self.buffer[:, self.current_ptr:] = data[:, :split]
|
self.buffer[:, old_ptr:] = data[:, :split]
|
||||||
self.buffer[:, :write_end - self.n_points] = data[:, split:]
|
self.buffer[:, :write_end - self.n_points] = data[:, split:]
|
||||||
|
|
||||||
# 更新指针(取模保证环形)和计数(不超过缓存总长度)
|
# -------- 第三步:再次加锁更新最终元数据 --------
|
||||||
self.current_ptr = write_end % self.n_points
|
with self.lock:
|
||||||
self.total_samples = min(self.total_samples + n, self.n_points)
|
self.current_ptr = new_ptr
|
||||||
|
self.total_samples = new_total
|
||||||
|
|
||||||
def getData(self, count):
|
def getData(self, count):
|
||||||
"""
|
"""
|
||||||
@@ -53,39 +58,33 @@ class FilterRingBuffer:
|
|||||||
:param count: 读取点数
|
:param count: 读取点数
|
||||||
:return: np.ndarray, shape=(n_chan, count)
|
:return: np.ndarray, shape=(n_chan, count)
|
||||||
"""
|
"""
|
||||||
# with self.lock:
|
# -------- 第一步:加锁获取最新元数据(持锁极短)--------
|
||||||
|
with self.lock:
|
||||||
count = min(count, self.total_samples)
|
count = min(count, self.total_samples)
|
||||||
if count == 0:
|
if count == 0:
|
||||||
return np.zeros((self.n_chan, 0))
|
return np.zeros((self.n_chan, 0))
|
||||||
|
|
||||||
# 环形读取:end是当前写入指针(最新数据的下一位),start是end - count
|
|
||||||
end = self.current_ptr
|
end = self.current_ptr
|
||||||
start = end - count
|
start = end - count
|
||||||
|
|
||||||
if start >= 0:
|
if start >= 0:
|
||||||
return self.buffer[:, start:end].copy()
|
res = self.buffer[:, start:end].copy()
|
||||||
else:
|
else:
|
||||||
# 跨环形边界:前半部分从缓存末尾取,后半部分从开头取
|
part1 = self.buffer[:, start:]
|
||||||
part1 = self.buffer[:, start:] # start为负,等价于n_points + start
|
|
||||||
part2 = self.buffer[:, :end]
|
part2 = self.buffer[:, :end]
|
||||||
return np.concatenate((part1, part2), axis=1)
|
res = np.concatenate((part1, part2), axis=1).copy()
|
||||||
|
return res
|
||||||
|
|
||||||
def get_latest_n_points(self, n):
|
def get_latest_n_points(self, n):
|
||||||
"""
|
with self.lock:
|
||||||
扩展方法:获取最新的n个点(不移动读指针,用于滑动窗口)
|
|
||||||
:param n: 点数
|
|
||||||
:return: np.ndarray, shape=(n_chan, n) | None(数据不足时)
|
|
||||||
"""
|
|
||||||
if self.total_samples < n:
|
if self.total_samples < n:
|
||||||
return None
|
return None
|
||||||
return self.getData(n)
|
return self.getData(n)
|
||||||
|
|
||||||
def GetDataLenCount(self):
|
def GetDataLenCount(self):
|
||||||
"""获取当前缓存总点数(兼容原有接口)"""
|
|
||||||
with self.lock:
|
with self.lock:
|
||||||
return self.total_samples
|
return self.total_samples
|
||||||
|
|
||||||
def resetAllPara(self):
|
def resetAllPara(self):
|
||||||
"""重置所有缓存和指针(兼容原有接口)"""
|
|
||||||
with self.lock:
|
with self.lock:
|
||||||
self.buffer.fill(0.0)
|
self.buffer.fill(0.0)
|
||||||
self.current_ptr = 0
|
self.current_ptr = 0
|
||||||
|
|||||||
Reference in New Issue
Block a user