范式相关代码修改完成
This commit is contained in:
240
Decoder.py
240
Decoder.py
@@ -24,7 +24,7 @@ from logs.log import algo_log
|
||||
from SSVEP.dwfbcca import FbccaDw
|
||||
# from Tools.plot_MI_EEG import plotMain
|
||||
from collections import deque
|
||||
|
||||
from Zmq.filterProcess import SlidingFilter
|
||||
|
||||
def get_root_path():
|
||||
"""
|
||||
@@ -64,6 +64,8 @@ class Decoder_main(threading.Thread):
|
||||
self.zmqServer = zmqServer(device_info=self.device_info)
|
||||
self.zmqServer.start()
|
||||
|
||||
self.filter = SlidingFilter()
|
||||
|
||||
# self.zmqClient = zmqClient(_upper_host, _upper_port)
|
||||
# self.zmqClient.set_zmq_server(self.zmqServer)
|
||||
# self.zmqClient.connect()
|
||||
@@ -188,10 +190,10 @@ class Decoder_main(threading.Thread):
|
||||
self.load_model = False # 调用模型是否完成的标志
|
||||
self.b_notch, self.a_notch = signal.iirnotch(50 / (self.device_info['sample_rate']/2), 30) # 50Hz工频陷波,250是采样率,30是质量因子
|
||||
self.b_design = signal.firwin(65, [bandPass_low / (self.device_info['sample_rate']/2), bandPass_high / (self.device_info['sample_rate']/2)], pass_zero=False) # 设计8-30Hz带通滤波器
|
||||
fileName = 'Model_' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
|
||||
filePath = os.path.join(get_root_path(), MODEL_FOLDER) + os.sep
|
||||
for old_pth in glob.glob(os.path.join(filePath, '*.pth')):
|
||||
os.remove(old_pth)
|
||||
fileName = 'Model_' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
|
||||
self.modelPath = ''.join([filePath, fileName, '.pth'])
|
||||
self.mp_data_queue = mp.Queue()
|
||||
self.mp_result_queue = mp.Queue()
|
||||
@@ -220,54 +222,6 @@ class Decoder_main(threading.Thread):
|
||||
# self.zmqClient.send_to_all('sync', self.zmqClient.state)
|
||||
self.zmqServer.state_mode = 'rest'
|
||||
|
||||
# # 状态异常,报告上位机
|
||||
# if self.status_code != self.thread_data_server.status_code:
|
||||
# self.status_code = self.thread_data_server.status_code
|
||||
# self.zmqClient.send_to_all('status_code', int(self.status_code))
|
||||
# print('status code')
|
||||
|
||||
# # 返回电量
|
||||
# if self.energy != self.thread_data_server.energy:
|
||||
# self.energy = self.thread_data_server.energy
|
||||
# self.zmqClient.send_to_all('energy', int(self.energy))
|
||||
# print('energy')
|
||||
|
||||
# if self.zmqServer.open_Impedance == True: # 开启阻抗检测功能,仅运行一次
|
||||
# self.thread_data_server.Impedance(True)
|
||||
# print('Impedance')
|
||||
# self.zmqServer.open_Impedance = -1
|
||||
# elif self.zmqServer.open_Impedance == False:
|
||||
# self.thread_data_server.Impedance(False)
|
||||
# self.zmqServer.open_Impedance = -1
|
||||
|
||||
# if self.zmqServer.get_Impedance: # 返回阻抗值
|
||||
# # print(self.zmqServer.get_Impedance)
|
||||
# # print(self.thread_data_server.GetDataLenCount())
|
||||
# if self.thread_data_server.GetDataLenCount() > 250:
|
||||
# Impe_data = self.thread_data_server.getData(250)
|
||||
# # 计算阻抗
|
||||
# imps = self.thread_data_server.getImpedance(Impe_data,self.zmqServer.decoder_class)
|
||||
# self.zmqClient.send_to_all('impedance', imps.tolist())
|
||||
# else:
|
||||
# pass
|
||||
# if self.zmqServer.getReport: #返回训练报告内容
|
||||
# self.zmqServer.getReport = False
|
||||
# allData = np.array(self.plotData)
|
||||
# allLabel = np.array(self.plotLabel) + 1
|
||||
# nTrials = min(len(allLabel),len(allData))
|
||||
# if nTrials < 30:
|
||||
# self.zmqClient.send_to_all('miReport',0)
|
||||
# else:
|
||||
# allData = allData[:nTrials]
|
||||
# allLabel = allLabel[:nTrials]
|
||||
# ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1',
|
||||
# 'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4']
|
||||
# compare_names = ['C3', 'CZ', 'C4']
|
||||
# miReport = plotMain(ch_names=ch_names,compare_names=compare_names,Data=allData,labels=allLabel,MI_label=1,Rest_label=2,
|
||||
# fs=self.device_info['sample_rate'])
|
||||
# self.zmqClient.send_to_all('miReport',miReport)
|
||||
|
||||
|
||||
# --- 取数优先:先执行 decoder(消费环形缓冲),再处理 plot/report 等重负载 ---
|
||||
try:
|
||||
if self.decoder_class == 'ssvep' or self.decoder_class == 'pvs':
|
||||
@@ -276,17 +230,12 @@ class Decoder_main(threading.Thread):
|
||||
self.decoder_SSMVEP()
|
||||
elif self.decoder_class == 'mi':
|
||||
self.decoder_MI()
|
||||
# elif self.decoder_class == 'concentration':
|
||||
# self.decoder_concentration()
|
||||
# elif self.decoder_class == 'blink':
|
||||
# self.decoder_blink()
|
||||
# else:
|
||||
# self.
|
||||
# # if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||
# # if self.thread_data_server.GetDataLenCount() < 25:
|
||||
# # time.sleep(0.005)
|
||||
# # continue;
|
||||
# # self.thread_data_server.getData(25)
|
||||
else:
|
||||
if self.zmqServer.open_Impedance == False: # 非阻抗检测状态
|
||||
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
|
||||
time.sleep(0.005)
|
||||
continue;
|
||||
self.zmqServer.paradigmBuffer.getData(25)
|
||||
except Exception as e:
|
||||
algo_log(f"Decoder Loop Error: {e}")
|
||||
time.sleep(0.1) # Prevent CPU spin if error is persistent
|
||||
@@ -344,18 +293,18 @@ class Decoder_main(threading.Thread):
|
||||
if self.zmqServer.StartTrain:
|
||||
self.currentLabel = self.zmqServer.currentLabel
|
||||
self.zmqServer.StartTrain = False
|
||||
if self.thread_data_server.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
||||
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
||||
self.train_epoch[1] \
|
||||
+ self.thread_data_server.event_inner_idx:
|
||||
+ self.zmqServer.event_inner_idx:
|
||||
time.sleep(0.0001)
|
||||
return
|
||||
print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount())
|
||||
trainTrial = self.thread_data_server.get_SSMVEPData() # 取出所有数据
|
||||
print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, self.thread_data_server.event_inner_idx])
|
||||
trainTrial = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 取出所有数据
|
||||
print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, self.zmqServer.event_inner_idx])
|
||||
trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理
|
||||
trainTrial = trainTrial[:, self.thread_data_server.event_inner_idx + self.train_epoch[
|
||||
0]:self.thread_data_server.event_inner_idx + self.train_epoch[1]]
|
||||
print('trial: ', self.thread_data_server.event_inner_idx, self.train_epoch[0], self.train_epoch[1])
|
||||
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.train_epoch[
|
||||
0]:self.zmqServer.event_inner_idx + self.train_epoch[1]]
|
||||
print('trial: ', self.zmqServer.event_inner_idx, self.train_epoch[0], self.train_epoch[1])
|
||||
if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance(
|
||||
self.trainLabel, list) \
|
||||
and self.trainLabel.count(self.currentLabel) < self.single_train:
|
||||
@@ -373,17 +322,17 @@ class Decoder_main(threading.Thread):
|
||||
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
||||
print('启动预测 ', formatted_time)
|
||||
|
||||
if self.thread_data_server.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
||||
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
||||
self.interval_epoch[1] \
|
||||
+ self.thread_data_server.event_inner_idx:
|
||||
+ self.zmqServer.event_inner_idx:
|
||||
time.sleep(0.0001)
|
||||
return
|
||||
data = self.thread_data_server.get_SSMVEPData() # 读取全部数据
|
||||
print('取出的: ', data.shape, 'event: ', data[-2, self.thread_data_server.event_inner_idx])
|
||||
data = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 读取全部数据
|
||||
print('取出的: ', data.shape, 'event: ', data[-2, self.zmqServer.event_inner_idx])
|
||||
data = self.preprocess(data[:self.n_chan, :]) # 预处理
|
||||
data = data[:,
|
||||
self.thread_data_server.event_inner_idx + self.interval_epoch[
|
||||
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]]
|
||||
self.zmqServer.event_inner_idx + self.interval_epoch[
|
||||
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]
|
||||
pad_eeg_test = np.zeros(
|
||||
(data.shape[0], int((self.sample_length + 0.1) * self.device_info['sample_rate'])))
|
||||
pad_eeg_test[:, :int(self.sample_length * self.device_info['sample_rate'])] = data
|
||||
@@ -392,20 +341,20 @@ class Decoder_main(threading.Thread):
|
||||
choosenNum = choosenNum[0]
|
||||
print('结果:', choosenNum, 'rho: ', sorted(features_2[0]),
|
||||
sorted(features_2[0])[-1] - sorted(features_2[0])[-2])
|
||||
self.zmqClient.send_to_all('result', int(choosenNum))
|
||||
self.zmqServer.broadcast_message('result', int(choosenNum))
|
||||
print('发送给界面完成。')
|
||||
else: # 休息状态
|
||||
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||
if self.zmqServer.open_Impedance == False: # 非阻抗检测状态
|
||||
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
|
||||
time.sleep(0.005)
|
||||
return
|
||||
self.thread_data_server.getData(25)
|
||||
self.zmqServer.paradigmBuffer.getData(25)
|
||||
|
||||
def decoder_MI(self):
|
||||
'''模型训练'''
|
||||
if self.train_started == False and all(
|
||||
self.trainLabel.count(i) >= self.single_train for i in range(self.num_target)): # 模型尚未训练
|
||||
self.zmqClient.send_to_all('paradigm', 2) # 模型训练前,训练集采集完毕,通知上位机
|
||||
self.zmqServer.broadcast_message('paradigm', 2) # 模型训练前,训练集采集完毕,通知上位机
|
||||
self.train_started = True
|
||||
self.trainData = np.array(self.trainData)
|
||||
self.trainLabel = np.array(self.trainLabel) + 1
|
||||
@@ -431,7 +380,7 @@ class Decoder_main(threading.Thread):
|
||||
with torch.no_grad():
|
||||
_ = self.model(warmup_data)
|
||||
self.load_model = True
|
||||
self.zmqClient.send_to_all('paradigm', 1) # 模型调用完毕,通知上位机
|
||||
self.zmqServer.broadcast_message('paradigm', 1) # 模型调用完毕,通知上位机
|
||||
else:
|
||||
print("训练失败:", result['msg'])
|
||||
except Empty:
|
||||
@@ -444,26 +393,26 @@ class Decoder_main(threading.Thread):
|
||||
if self.zmqServer.StartTrain:
|
||||
self.currentLabel = self.zmqServer.currentLabel
|
||||
self.zmqServer.StartTrain = False
|
||||
if self.thread_data_server.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
||||
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
||||
self.interval_epoch[1] \
|
||||
+ self.thread_data_server.event_inner_idx:
|
||||
+ self.zmqServer.event_inner_idx:
|
||||
time.sleep(0.0001)
|
||||
return
|
||||
print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount())
|
||||
originalTrial = self.thread_data_server.get_MIData() # 取出MI导联数据
|
||||
print('取出的: ', originalTrial.shape, 'event: ', originalTrial[-2, self.thread_data_server.event_inner_idx])
|
||||
originalTrial = self.zmqServer.paradigmBuffer.get_MIData() # 取出MI导联数据
|
||||
print('取出的: ', originalTrial.shape, 'event: ', originalTrial[-2, self.zmqServer.event_inner_idx])
|
||||
trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理
|
||||
trainTrial = trainTrial[:, self.thread_data_server.event_inner_idx + self.interval_epoch[
|
||||
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]]
|
||||
print('trial: ', self.thread_data_server.event_inner_idx, self.interval_epoch[0], self.interval_epoch[1])
|
||||
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.interval_epoch[
|
||||
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]
|
||||
print('trial: ', self.zmqServer.event_inner_idx, self.interval_epoch[0], self.interval_epoch[1])
|
||||
if trainTrial.shape[1] == (self.interval_epoch[1] - self.interval_epoch[0]) and isinstance(self.trainLabel,
|
||||
list) \
|
||||
and self.trainLabel.count(self.currentLabel) < self.single_train:
|
||||
self.trainData.append(trainTrial)
|
||||
self.trainLabel.append(self.currentLabel)
|
||||
print('训练集:', np.shape(self.trainData))
|
||||
self.plotData.append(originalTrial[:self.n_chan, self.thread_data_server.event_inner_idx + self.interval_epoch[
|
||||
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]])
|
||||
self.plotData.append(originalTrial[:self.n_chan, self.zmqServer.event_inner_idx + self.interval_epoch[
|
||||
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]])
|
||||
self.plotLabel.append(self.currentLabel)
|
||||
|
||||
elif self.zmqServer.state_mode == 'predict' and self.load_model == True: # 测试状态
|
||||
@@ -473,21 +422,21 @@ class Decoder_main(threading.Thread):
|
||||
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
||||
print('启动预测 ', formatted_time)
|
||||
|
||||
if self.thread_data_server.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
||||
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
||||
self.interval_epoch[1] \
|
||||
+ self.thread_data_server.event_inner_idx:
|
||||
+ self.zmqServer.event_inner_idx:
|
||||
time.sleep(0.0001)
|
||||
return
|
||||
originalData = self.thread_data_server.get_MIData() # 读取全部数据
|
||||
print('取出的: ', originalData.shape, 'event: ', originalData[-2, self.thread_data_server.event_inner_idx])
|
||||
originalData = self.zmqServer.paradigmBuffer.get_MIData() # 读取全部数据
|
||||
print('取出的: ', originalData.shape, 'event: ', originalData[-2, self.zmqServer.event_inner_idx])
|
||||
start = time.time()
|
||||
data = self.preprocess(originalData[:self.n_chan, :]) # 预处理
|
||||
data = data[:,
|
||||
self.thread_data_server.event_inner_idx + self.interval_epoch[
|
||||
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]]
|
||||
self.zmqServer.event_inner_idx + self.interval_epoch[
|
||||
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]
|
||||
self.plotData.append(
|
||||
originalData[:self.n_chan, self.thread_data_server.event_inner_idx + self.interval_epoch[
|
||||
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]])
|
||||
originalData[:self.n_chan, self.zmqServer.event_inner_idx + self.interval_epoch[
|
||||
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]])
|
||||
|
||||
test_data = data[np.newaxis, np.newaxis, :, :]
|
||||
test_data = torch.from_numpy(test_data)
|
||||
@@ -497,15 +446,15 @@ class Decoder_main(threading.Thread):
|
||||
y_pred = torch.max(Cls, 1)[1]
|
||||
self.plotLabel.append(int(y_pred.item()))
|
||||
print('运动意图识别: ', y_pred)
|
||||
self.zmqClient.send_to_all('result', int(y_pred.item()))
|
||||
self.zmqServer.broadcast_message('paradigm', int(y_pred.item()))
|
||||
end = time.time()
|
||||
print(f'发送给界面完成,耗时{end - start:.3f}s。')
|
||||
else: # 休息状态
|
||||
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||
if self.zmqServer.open_Impedance == False: # 非阻抗检测状态
|
||||
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
|
||||
time.sleep(0.005)
|
||||
return
|
||||
self.thread_data_server.getData(25)
|
||||
self.zmqServer.paradigmBuffer.getData(25)
|
||||
|
||||
# def decoder_concentration(self):
|
||||
# if self.zmqServer.state_mode == 'predict':
|
||||
@@ -531,99 +480,6 @@ class Decoder_main(threading.Thread):
|
||||
# return
|
||||
# self.thread_data_server.getData(25)
|
||||
|
||||
#### Blink detection #####
|
||||
# def check_double_blink(self, current_time):
|
||||
# """
|
||||
# 检查是否检测到连续两次眨眼
|
||||
# @param current_time: 当前眨眼时间戳
|
||||
# @return: True表示检测到连续两次眨眼
|
||||
# """
|
||||
# if len(self.blink_timestamps) < 2:
|
||||
# return False
|
||||
|
||||
# # 检查是否在去抖期内
|
||||
# if self.last_double_blink_time > 0:
|
||||
# time_since_last_double_blink = current_time - self.last_double_blink_time
|
||||
# if time_since_last_double_blink < self.double_blink_jitter:
|
||||
# return False # 在去抖期内,忽略连续眨眼检测
|
||||
# last_time = self.blink_timestamps[-1] # 当前眨眼
|
||||
# prev_time = self.blink_timestamps[-2] # 上次眨眼
|
||||
|
||||
# interval = last_time - prev_time
|
||||
# if interval <= self.double_blink_interval:
|
||||
# return True
|
||||
|
||||
# return False
|
||||
|
||||
# def process_blink_detection(self):
|
||||
# """
|
||||
# 在缓冲区数据上执行,单次眨眼检测
|
||||
# """
|
||||
# if len(self.fp1_buffer) < self.window_samples:
|
||||
# return
|
||||
|
||||
# fp1_data = np.array(list(self.fp1_buffer)[-self.window_samples:])
|
||||
# fp2_data = np.array(list(self.fp2_buffer)[-self.window_samples:])
|
||||
# # 计算FP1和FP2的平均
|
||||
# fp12_mean = (fp1_data + fp2_data) / 2.0
|
||||
# # 带通滤波
|
||||
# try:
|
||||
# fp12_filtered = signal.filtfilt(self.blink_b, self.blink_a, fp12_mean)
|
||||
# except Exception as e:
|
||||
# print(f"Filter error: {e}")
|
||||
# return
|
||||
# F = np.diff(fp12_filtered)
|
||||
# if len(F) < 3:
|
||||
# return
|
||||
# b, d, e = blink_detection(F, self.device_info['sample_rate'], self.Dmin, self.Dmax, self.EMin, self.EMax)
|
||||
|
||||
# if b == 1:
|
||||
# samples_since_last = self.total_samples - self.last_blink_time
|
||||
# time_since_last_ms = (samples_since_last / self.device_info['sample_rate']) * 1000
|
||||
# if time_since_last_ms >= self.jitterwin: # self.jitterwin 单次眨眼去抖 using time_since_last_ms
|
||||
# self.blink_count += 1
|
||||
# self.last_blink_time = self.total_samples
|
||||
# current_time = time.time()
|
||||
# self.blink_timestamps.append(current_time)
|
||||
# blink_event = {
|
||||
# 'count': self.blink_count,
|
||||
# 'time': current_time,
|
||||
# 'sample_index': self.total_samples,
|
||||
# 'duration_ms': d,
|
||||
# 'energy': e
|
||||
# }
|
||||
# self.blink_events.append(blink_event)
|
||||
# self.zmqClient.send_to_all('result', 1) # 检测到眨眼信号,通知上位机
|
||||
# if self.check_double_blink(current_time):
|
||||
# self.double_blink_count += 1
|
||||
# interval = self.blink_timestamps[-1] - self.blink_timestamps[-2]
|
||||
# double_blink_event = {
|
||||
# 'double_blink_count': self.double_blink_count,
|
||||
# 'blink1_time': self.blink_timestamps[-2],
|
||||
# 'blink2_time': self.blink_timestamps[-1],
|
||||
# 'interval': interval
|
||||
# }
|
||||
# self.double_blink_events.append(double_blink_event)
|
||||
# self.last_double_blink_time = current_time
|
||||
# self.zmqClient.send_to_all('result', 2) # 发送双次眨眼事件
|
||||
|
||||
# def decoder_blink(self):
|
||||
# if self.thread_data_server.GetDataLenCount() < 50:
|
||||
# time.sleep(0.005)
|
||||
# return
|
||||
# if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||
# data = self.thread_data_server.get_blinkData(50)
|
||||
# fp1_data = data[0, :] # ch1 (相当于FP1)
|
||||
# fp2_data = data[1, :] # ch2 (相当于FP2)
|
||||
# for i in range(len(fp1_data)):
|
||||
# self.fp1_buffer.append(fp1_data[i])
|
||||
# self.fp2_buffer.append(fp2_data[i])
|
||||
# self.total_samples += 1
|
||||
# self.sample_counter += 1
|
||||
|
||||
# if self.sample_counter >= self.step_samples:
|
||||
# self.process_blink_detection()
|
||||
# self.sample_counter = 0
|
||||
|
||||
def stop(self):
|
||||
'''
|
||||
|
||||
11
README.md
11
README.md
@@ -17,5 +17,12 @@ Debug_64ch_Decoder_Optimize is an updated version that fixes several issues and
|
||||
2026年6月5日13:55:34
|
||||
|
||||
# 遗留问题
|
||||
1. 之前当处于阻抗检测状态时,Decoder在空跑。
|
||||
2. 当前无法判断是否处于阻抗检测状态。
|
||||
1. 之前当处于阻抗检测状态时,Decoder在空跑。当前无法判断是否处于阻抗检测状态。
|
||||
- 解决方法,保留之前发阻抗命令
|
||||
|
||||
|
||||
# 常用命令
|
||||
source activate 3in1Py310
|
||||
python runDecoder.py
|
||||
python datamock.py
|
||||
python ZeroMQClient_mock.py
|
||||
|
||||
@@ -11,7 +11,7 @@ class ParadigmRingBuffer:
|
||||
def __init__(self, n_chan, n_points):
|
||||
self.n_chan = n_chan
|
||||
self.n_points = n_points
|
||||
self.buffer = np.zeros((n_chan, n_points))
|
||||
self.buffer = np.zeros((n_chan, n_points), dtype=np.float32)
|
||||
self.currentPtr = 0
|
||||
self.readPtr = 0
|
||||
self.nUpdate = 0
|
||||
@@ -117,6 +117,6 @@ class ParadigmRingBuffer:
|
||||
def resetAllPara(self):
|
||||
self.nUpdate = 0
|
||||
self.currentPtr = 0
|
||||
self.readPtr = 0 # add by lizhenhua 清空读指针
|
||||
self.buffer = np.zeros((self.n_chan, self.n_points)) # add by lizhenhua 清空环形缓冲区
|
||||
self.readPtr = 0
|
||||
self.buffer.fill(0.0)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
"""
|
||||
import numpy as np
|
||||
import threading
|
||||
from scipy import signal
|
||||
from logs.log import algo_log
|
||||
|
||||
class FilterRingBuffer:
|
||||
@@ -16,7 +17,7 @@ class FilterRingBuffer:
|
||||
self.n_chan = n_chan
|
||||
self.n_points = n_points
|
||||
|
||||
self.buffer = np.zeros((n_chan, n_points), dtype=np.float64)
|
||||
self.buffer = np.zeros((n_chan, n_points), dtype=np.float32)
|
||||
self.current_ptr = 0 # 写入指针
|
||||
self.total_samples = 0 # 已写入总点数
|
||||
|
||||
|
||||
@@ -4,6 +4,9 @@ import threading
|
||||
import json
|
||||
import queue
|
||||
from typing import Dict
|
||||
import datetime
|
||||
import time
|
||||
|
||||
# from Device.SunnyLinker import SunnyLinker64
|
||||
from Zmq.dataBuffer import ParadigmRingBuffer
|
||||
from Zmq.filterProcess import FilterRingBuffer
|
||||
@@ -74,43 +77,16 @@ class zmqServer(threading.Thread):
|
||||
|
||||
# 范式buffer参数, 事件检测相关
|
||||
self._event_lock = threading.Lock()
|
||||
self._epoch_finished = False
|
||||
self._event_inner_idx = -1
|
||||
self.pack_contain_event = False
|
||||
|
||||
self.predict_event = 99
|
||||
self.events = [1, 2, self.predict_event]
|
||||
self.count_events = {}
|
||||
self.latency = 50
|
||||
self.train_latency = 50
|
||||
self._interval_inited = False
|
||||
|
||||
@property
|
||||
def interval_inited(self):
|
||||
return self._interval_inited
|
||||
|
||||
@interval_inited.setter
|
||||
def interval_inited(self, value):
|
||||
self._interval_inited = value
|
||||
|
||||
@property
|
||||
def epoch_finished(self):
|
||||
with self._event_lock:
|
||||
return self._epoch_finished
|
||||
|
||||
@epoch_finished.setter
|
||||
def epoch_finished(self, value):
|
||||
with self._event_lock:
|
||||
self._epoch_finished = value
|
||||
|
||||
@property
|
||||
def event_inner_idx(self):
|
||||
with self._event_lock:
|
||||
return self._event_inner_idx
|
||||
|
||||
@event_inner_idx.setter
|
||||
def event_inner_idx(self, value):
|
||||
with self._event_lock:
|
||||
self._event_inner_idx = value
|
||||
self.count_events = {}
|
||||
self.epoch_finished = False
|
||||
self.pack_contain_event = False
|
||||
self.event_inner_idx = -1
|
||||
self.interval_inited = False
|
||||
|
||||
def reset_state(self):
|
||||
"""清空采集器状态和缓存数据"""
|
||||
@@ -148,10 +124,6 @@ class zmqServer(threading.Thread):
|
||||
self.predict_event = 99
|
||||
self.events = [1, 2, self.predict_event]
|
||||
self.interval_inited = True
|
||||
# if getattr(self, 'serial', None) and self.serial.is_open:
|
||||
# self.serial.close()
|
||||
# self.serial = serial.Serial(self.serial_port, 460800, timeout=1) # 连接同步器串口
|
||||
|
||||
|
||||
def broadcast_message(self, method, params):
|
||||
"""Put message into queue to be sent to all command clients"""
|
||||
@@ -262,10 +234,26 @@ class zmqServer(threading.Thread):
|
||||
data_np = data_np.reshape(self.device_info['frame_points'], self.device_info['channel_nums'])
|
||||
data_np = data_np.T.astype(np.float64)
|
||||
|
||||
# 6. 写入缓冲区
|
||||
self.paradigmBuffer.appendBuffer(data_np)
|
||||
# 6. 写入滤波缓冲区
|
||||
self.filterBuffer.appendBuffer(data_np)
|
||||
|
||||
# 7. 写入范式缓冲区
|
||||
try:
|
||||
with self.paradigmBufferLock:
|
||||
if self.interval_inited:
|
||||
self.epoch_finished = self.detect_event(data_np)
|
||||
if self.pack_contain_event:
|
||||
self.paradigmBuffer.resetAllPara() # 检测到当前pack含有event,清除ringbuffer中之前的数据
|
||||
self.paradigmBuffer.appendBuffer(data_np)
|
||||
if self.epoch_finished:
|
||||
time.sleep(0.005)
|
||||
algo_log('epoch_finished: ' + datetime.datetime.now().strftime('%H:%M:%S.%f')[:-3], level="DEBUG")
|
||||
else:
|
||||
self.paradigmBuffer.appendBuffer(data_np)
|
||||
except Exception as e:
|
||||
print("锁:写入异常",e)
|
||||
self.paradigmBuffer.appendBuffer(data_np)
|
||||
|
||||
# algo_log(f"数据写入成功:shape={data_np.shape}, 范围=[{data_np.min():.2f}, {data_np.max():.2f}] μV", level="DEBUG")
|
||||
|
||||
except Exception as e:
|
||||
@@ -274,6 +262,38 @@ class zmqServer(threading.Thread):
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
# 检测是否含有标签
|
||||
def detect_event(self, samples):
|
||||
self.pack_contain_event = False
|
||||
events = np.array(samples[-2])[0].tolist()
|
||||
for idx, event in enumerate(events):
|
||||
if int(event) in self.events:
|
||||
new_key = "".join(
|
||||
[
|
||||
str(event),
|
||||
datetime.datetime.now().strftime("%Y-%m-%d \
|
||||
-%H-%M-%S"),
|
||||
]
|
||||
)
|
||||
if event == self.predict_event:
|
||||
self.count_events[new_key] = self.latency + 1
|
||||
else:
|
||||
self.count_events[new_key] = self.train_latency + 1
|
||||
self.event_inner_idx = idx
|
||||
self.pack_contain_event = True
|
||||
drop_items = []
|
||||
for key, value in self.count_events.items():
|
||||
value = value - 1
|
||||
if value == 0:
|
||||
drop_items.append(key)
|
||||
self.count_events[key] = value
|
||||
for key in drop_items:
|
||||
del self.count_events[key]
|
||||
if drop_items:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
|
||||
def _process_send_queue(self):
|
||||
"""处理发送队列,向所有命令客户端广播消息"""
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Reference in New Issue
Block a user