Files
bci_algo/Zmq/dataBuffer.py
2026-06-06 14:40:07 +08:00

121 lines
4.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*-coding:utf-8 -*-
"""
范式buffer和滤波buffer, 以及滤波函数
"""
import numpy as np
from scipy import signal
import threading
class ParadigmRingBuffer:
    """Fixed-capacity ring buffer for multi-channel samples.

    Samples are stored column-wise in ``buffer`` (one row per channel).
    ``appendBuffer`` writes at ``currentPtr``, ``getData`` consumes from
    ``readPtr``, and ``nUpdate`` counts the samples written but not yet read.
    The ``get_*`` accessors consume unread samples and return only a fixed
    subset of rows for a given paradigm.
    """

    # Row indices extracted by each paradigm accessor.
    # NOTE(review): the original notes describe rows 64/65 as event rows
    # appended after the EEG channels -- confirm against the data producer.
    _MI_ROWS = (8, 15, 12, 14, 18, 23, 16, 59, 50, 58, 17, 45,
                29, 11, 10, 19, 20, 61, 51, 60, 21, 64, 65)
    _SSMVEP_ROWS = (13, 3, 2, 46, 9, 54, 47, 55, 64, 65)
    _SSVEP_ROWS = (13, 3, 2, 46, 9, 54, 47, 55, 64)
    _CONCENTRATE_ROWS = (0, 1)
    _BLINK_ROWS = (0, 1)

    def __init__(self, n_chan, n_points):
        """Create an empty buffer.

        Args:
            n_chan: number of rows (channels) stored per sample.
            n_points: capacity of the buffer in samples (columns).
        """
        self.n_chan = n_chan
        self.n_points = n_points
        self.buffer = np.zeros((n_chan, n_points))
        self.currentPtr = 0  # column where the next write starts
        self.readPtr = 0     # column where the next read starts
        self.nUpdate = 0     # number of written-but-unread samples
        self.rawData = np.zeros((n_chan, 1))

    def appendBuffer(self, data):
        """Append samples and advance the write pointer.

        Only as many leading columns of ``data`` as fit into the free space
        are stored; any remaining columns are dropped.

        Args:
            data: 2-D array whose columns are samples and whose rows match
                the buffer's channels.

        Raises:
            Exception: if the buffer has no free space at all.
        """
        if self.nUpdate == self.n_points:
            raise Exception("Buffer is full")
        n = data.shape[1]
        # Never overwrite unread samples: cap the write at the free space.
        write_count = min(self.n_points - self.nUpdate, n)
        # Modular column indices let a single assignment wrap past the end.
        cols = np.mod(
            np.arange(self.currentPtr, self.currentPtr + write_count),
            self.n_points,
        )
        self.buffer[:, cols] = data[:, :write_count]
        self.currentPtr = (self.currentPtr + write_count) % self.n_points
        self.nUpdate += write_count

    def getData(self, count=50):
        """Consume and return up to ``count`` of the oldest unread samples.

        Args:
            count: maximum number of samples to read. Values above the number
                of unread samples are clamped to it; negative values are
                treated as 0.

        Returns:
            A new array of shape ``(n_chan, k)`` where ``k`` is the number of
            samples actually read (0 when nothing is unread). The result is
            always a copy, so later writes to the buffer cannot alter it.
        """
        # Clamp to [0, unread]; a zero-length read must yield zero columns
        # rather than falling through to a full-buffer slice.
        count = max(0, min(count, self.nUpdate))
        # Modular indices handle both the contiguous and the wrapped case,
        # and integer-array indexing returns a copy instead of a view.
        cols = (self.readPtr + np.arange(count)) % self.n_points
        data = self.buffer[:, cols]
        self.readPtr = (self.readPtr + count) % self.n_points
        self.nUpdate -= count
        return data

    def GetDataLenCount(self):
        """Return the number of unread samples currently held per channel."""
        return self.nUpdate

    @staticmethod
    def _select_rows(data, rows):
        """Return the given rows of ``data``, or an empty block if it has no columns."""
        if data.shape[1] > 0:
            return data[np.array(rows), :]
        return np.zeros((len(rows), 0))

    # ========== Per-paradigm data access ==========
    def get_MIData(self):
        """Consume all unread samples and return the MI rows.

        Original note: 21 channels plus event rows.
        """
        data = self.getData(self.GetDataLenCount())
        return self._select_rows(data, self._MI_ROWS)

    def get_SSMVEPData(self):
        """Consume all unread samples and return the SSMVEP rows.

        Original note: 8 channels plus event rows.
        """
        data = self.getData(self.GetDataLenCount())
        return self._select_rows(data, self._SSMVEP_ROWS)

    def getDataViaSSVEP(self, count):
        """Consume up to ``count`` samples and return the SSVEP rows.

        Original note: 8 channels plus an event row.
        """
        data = self.getData(count)
        return self._select_rows(data, self._SSVEP_ROWS)

    def get_concentrateData(self, count):
        """Consume up to ``count`` samples and return the concentration rows (rows 0 and 1)."""
        data = self.getData(count)
        return self._select_rows(data, self._CONCENTRATE_ROWS)

    def get_blinkData(self, count):
        """Consume up to ``count`` samples and return the blink rows (rows 0 and 1)."""
        data = self.getData(count)
        return self._select_rows(data, self._BLINK_ROWS)

    def resetAllPara(self):
        """Discard all samples: reset both pointers, the unread count and the storage."""
        self.nUpdate = 0
        self.currentPtr = 0
        self.readPtr = 0
        # A fresh array (rather than an in-place fill) replaces the storage.
        self.buffer = np.zeros((self.n_chan, self.n_points))