# -*-coding:utf-8 -*-
"""
Paradigm buffer and filter buffer, plus filtering functions.
"""
import numpy as np
from scipy import signal
import threading
from logs.log import algo_log


class ParadigmRingBuffer:
    """Fixed-capacity ring buffer of multi-channel samples.

    Samples are stored column-wise in a ``(n_chan, n_points)`` float64 array.
    ``appendBuffer`` writes columns at ``currentPtr`` and ``getData`` consumes
    columns starting at ``readPtr``; ``nUpdate`` is the number of columns that
    have been written but not yet read.  Unread data is never overwritten:
    columns that do not fit are dropped and a log entry is recorded.

    No locking is performed here; if writer and reader run on different
    threads, callers must synchronize access themselves.
    """

    # Row (channel) indices picked out by each paradigm accessor below.
    _MI_ROWS = (8, 15, 12, 14, 18, 23, 16, 59, 50, 58, 17, 45, 29, 11, 10,
                19, 20, 61, 51, 60, 21, 64, 65)
    _SSMVEP_ROWS = (13, 3, 2, 46, 9, 54, 47, 55, 64, 65)
    _SSVEP_ROWS = (13, 3, 2, 46, 9, 54, 47, 55, 64)
    _CONCENTRATE_ROWS = (0, 1)
    _BLINK_ROWS = (0, 1)

    def __init__(self, n_chan, n_points):
        """Create an empty buffer with ``n_chan`` rows and ``n_points`` columns."""
        self.n_chan = n_chan
        self.n_points = n_points
        self.buffer = np.zeros((n_chan, n_points), dtype=np.float64)
        self.currentPtr = 0  # next column index to write
        self.readPtr = 0  # next column index to read
        self.nUpdate = 0  # columns written but not yet read
        # Not referenced by any method of this class; kept for external users.
        self.rawData = np.zeros((n_chan, 1))

    def appendBuffer(self, data):
        """Append the columns of ``data`` and advance the write pointer.

        @param data: 2-D array sliced as ``data[:, :k]``; its first dimension
            presumably equals ``n_chan`` (TODO confirm with callers).

        Only as many columns as there is free space for are stored.  If any
        incoming column is dropped (or the buffer is already full), the message
        "ParadigmRingBuffer is full" is logged (once, via ``record_once``).
        """
        n = data.shape[1]
        free = self.n_points - self.nUpdate
        if free == 0 or n > free:
            # Previously a partial overflow dropped samples without any log.
            algo_log("ParadigmRingBuffer is full", record_once=True)

        write_count = min(free, n)
        if write_count == 0:
            return

        # Target column indices, wrapping past the end of the storage.
        cols = np.arange(self.currentPtr, self.currentPtr + write_count) % self.n_points
        self.buffer[:, cols] = data[:, :write_count]

        self.currentPtr = (self.currentPtr + write_count) % self.n_points
        self.nUpdate += write_count

    def getData(self, count=50):
        """Consume and return up to ``count`` unread columns, oldest first.

        @param count: maximum number of columns to read; clamped to the range
            ``[0, nUpdate]``.
        @return: a new ``(n_chan, k)`` array, ``k <= count``.  The result never
            aliases the internal storage, so later appends cannot modify it.
        """
        count = max(0, min(count, self.nUpdate))
        if count == 0:
            # The old end-index logic turned an empty read with readPtr == 0
            # into a read of the whole buffer (n_points stale columns).
            return np.zeros((self.n_chan, 0), dtype=self.buffer.dtype)

        start = self.readPtr
        stop = start + count
        if stop <= self.n_points:
            # Contiguous region; copy so the caller does not hold a view into
            # columns that the writer is now free to overwrite.
            data = self.buffer[:, start:stop].copy()
        else:
            # Region wraps: tail of the storage followed by its head.
            data = np.concatenate(
                (self.buffer[:, start:], self.buffer[:, :stop - self.n_points]),
                axis=1,
            )

        self.readPtr = stop % self.n_points
        self.nUpdate -= count
        return data

    def GetDataLenCount(self):
        '''
        Return the number of unread columns currently held in the buffer.
        @return: int
        '''
        return self.nUpdate

    @staticmethod
    def _select_rows(data, rows):
        """Return ``data`` restricted to ``rows``, or an empty ``(len(rows), 0)`` array."""
        if data.shape[1] > 0:
            return data[np.array(rows), :]
        return np.zeros((len(rows), 0))

    # ========== Per-paradigm data accessors ==========
    def get_MIData(self):
        """Consume all unread data and return the MI leads (21 channels + event)."""
        return self._select_rows(self.getData(self.GetDataLenCount()), self._MI_ROWS)

    def get_SSMVEPData(self):
        """Consume all unread data and return the SSMVEP leads (8 channels + event)."""
        return self._select_rows(self.getData(self.GetDataLenCount()), self._SSMVEP_ROWS)

    def getDataViaSSVEP(self, count):
        """Consume up to ``count`` columns and return the SSVEP leads (8 channels + event)."""
        return self._select_rows(self.getData(count), self._SSVEP_ROWS)

    def get_concentrateData(self, count):
        """Consume up to ``count`` columns and return the concentration data (2 channels)."""
        return self._select_rows(self.getData(count), self._CONCENTRATE_ROWS)

    def get_blinkData(self, count):
        """Consume up to ``count`` columns and return the blink data (2 channels)."""
        return self._select_rows(self.getData(count), self._BLINK_ROWS)

    def resetAllPara(self):
        """Discard all buffered data and reset the read/write pointers."""
        self.nUpdate = 0
        self.currentPtr = 0
        self.readPtr = 0
        self.buffer.fill(0.0)