范式相关代码修改完成
This commit is contained in:
@@ -11,7 +11,7 @@ class ParadigmRingBuffer:
|
||||
def __init__(self, n_chan, n_points):
|
||||
self.n_chan = n_chan
|
||||
self.n_points = n_points
|
||||
self.buffer = np.zeros((n_chan, n_points))
|
||||
self.buffer = np.zeros((n_chan, n_points), dtype=np.float32)
|
||||
self.currentPtr = 0
|
||||
self.readPtr = 0
|
||||
self.nUpdate = 0
|
||||
@@ -117,6 +117,6 @@ class ParadigmRingBuffer:
|
||||
def resetAllPara(self):
|
||||
self.nUpdate = 0
|
||||
self.currentPtr = 0
|
||||
self.readPtr = 0 # add by lizhenhua 清空读指针
|
||||
self.buffer = np.zeros((self.n_chan, self.n_points)) # add by lizhenhua 清空环形缓冲区
|
||||
self.readPtr = 0
|
||||
self.buffer.fill(0.0)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
"""
|
||||
import numpy as np
|
||||
import threading
|
||||
from scipy import signal
|
||||
from logs.log import algo_log
|
||||
|
||||
class FilterRingBuffer:
|
||||
@@ -16,7 +17,7 @@ class FilterRingBuffer:
|
||||
self.n_chan = n_chan
|
||||
self.n_points = n_points
|
||||
|
||||
self.buffer = np.zeros((n_chan, n_points), dtype=np.float64)
|
||||
self.buffer = np.zeros((n_chan, n_points), dtype=np.float32)
|
||||
self.current_ptr = 0 # 写入指针
|
||||
self.total_samples = 0 # 已写入总点数
|
||||
|
||||
|
||||
@@ -4,6 +4,9 @@ import threading
|
||||
import json
|
||||
import queue
|
||||
from typing import Dict
|
||||
import datetime
|
||||
import time
|
||||
|
||||
# from Device.SunnyLinker import SunnyLinker64
|
||||
from Zmq.dataBuffer import ParadigmRingBuffer
|
||||
from Zmq.filterProcess import FilterRingBuffer
|
||||
@@ -74,43 +77,16 @@ class zmqServer(threading.Thread):
|
||||
|
||||
# 范式buffer参数, 事件检测相关
|
||||
self._event_lock = threading.Lock()
|
||||
self._epoch_finished = False
|
||||
self._event_inner_idx = -1
|
||||
self.pack_contain_event = False
|
||||
|
||||
self.predict_event = 99
|
||||
self.events = [1, 2, self.predict_event]
|
||||
self.count_events = {}
|
||||
self.latency = 50
|
||||
self.train_latency = 50
|
||||
self._interval_inited = False
|
||||
|
||||
@property
|
||||
def interval_inited(self):
|
||||
return self._interval_inited
|
||||
|
||||
@interval_inited.setter
|
||||
def interval_inited(self, value):
|
||||
self._interval_inited = value
|
||||
|
||||
@property
|
||||
def epoch_finished(self):
|
||||
with self._event_lock:
|
||||
return self._epoch_finished
|
||||
|
||||
@epoch_finished.setter
|
||||
def epoch_finished(self, value):
|
||||
with self._event_lock:
|
||||
self._epoch_finished = value
|
||||
|
||||
@property
|
||||
def event_inner_idx(self):
|
||||
with self._event_lock:
|
||||
return self._event_inner_idx
|
||||
|
||||
@event_inner_idx.setter
|
||||
def event_inner_idx(self, value):
|
||||
with self._event_lock:
|
||||
self._event_inner_idx = value
|
||||
self.count_events = {}
|
||||
self.epoch_finished = False
|
||||
self.pack_contain_event = False
|
||||
self.event_inner_idx = -1
|
||||
self.interval_inited = False
|
||||
|
||||
def reset_state(self):
|
||||
"""清空采集器状态和缓存数据"""
|
||||
@@ -148,10 +124,6 @@ class zmqServer(threading.Thread):
|
||||
self.predict_event = 99
|
||||
self.events = [1, 2, self.predict_event]
|
||||
self.interval_inited = True
|
||||
# if getattr(self, 'serial', None) and self.serial.is_open:
|
||||
# self.serial.close()
|
||||
# self.serial = serial.Serial(self.serial_port, 460800, timeout=1) # 连接同步器串口
|
||||
|
||||
|
||||
def broadcast_message(self, method, params):
|
||||
"""Put message into queue to be sent to all command clients"""
|
||||
@@ -262,9 +234,25 @@ class zmqServer(threading.Thread):
|
||||
data_np = data_np.reshape(self.device_info['frame_points'], self.device_info['channel_nums'])
|
||||
data_np = data_np.T.astype(np.float64)
|
||||
|
||||
# 6. 写入缓冲区
|
||||
self.paradigmBuffer.appendBuffer(data_np)
|
||||
# 6. 写入滤波缓冲区
|
||||
self.filterBuffer.appendBuffer(data_np)
|
||||
|
||||
# 7. 写入范式缓冲区
|
||||
try:
|
||||
with self.paradigmBufferLock:
|
||||
if self.interval_inited:
|
||||
self.epoch_finished = self.detect_event(data_np)
|
||||
if self.pack_contain_event:
|
||||
self.paradigmBuffer.resetAllPara() # 检测到当前pack含有event,清除ringbuffer中之前的数据
|
||||
self.paradigmBuffer.appendBuffer(data_np)
|
||||
if self.epoch_finished:
|
||||
time.sleep(0.005)
|
||||
algo_log('epoch_finished: ' + datetime.datetime.now().strftime('%H:%M:%S.%f')[:-3], level="DEBUG")
|
||||
else:
|
||||
self.paradigmBuffer.appendBuffer(data_np)
|
||||
except Exception as e:
|
||||
print("锁:写入异常",e)
|
||||
self.paradigmBuffer.appendBuffer(data_np)
|
||||
|
||||
# algo_log(f"数据写入成功:shape={data_np.shape}, 范围=[{data_np.min():.2f}, {data_np.max():.2f}] μV", level="DEBUG")
|
||||
|
||||
@@ -274,6 +262,38 @@ class zmqServer(threading.Thread):
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
# 检测是否含有标签
|
||||
def detect_event(self, samples):
|
||||
self.pack_contain_event = False
|
||||
events = np.array(samples[-2])[0].tolist()
|
||||
for idx, event in enumerate(events):
|
||||
if int(event) in self.events:
|
||||
new_key = "".join(
|
||||
[
|
||||
str(event),
|
||||
datetime.datetime.now().strftime("%Y-%m-%d \
|
||||
-%H-%M-%S"),
|
||||
]
|
||||
)
|
||||
if event == self.predict_event:
|
||||
self.count_events[new_key] = self.latency + 1
|
||||
else:
|
||||
self.count_events[new_key] = self.train_latency + 1
|
||||
self.event_inner_idx = idx
|
||||
self.pack_contain_event = True
|
||||
drop_items = []
|
||||
for key, value in self.count_events.items():
|
||||
value = value - 1
|
||||
if value == 0:
|
||||
drop_items.append(key)
|
||||
self.count_events[key] = value
|
||||
for key in drop_items:
|
||||
del self.count_events[key]
|
||||
if drop_items:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
|
||||
def _process_send_queue(self):
|
||||
"""处理发送队列,向所有命令客户端广播消息"""
|
||||
|
||||
Reference in New Issue
Block a user