范式相关代码修改完成
This commit is contained in:
240
Decoder.py
240
Decoder.py
@@ -24,7 +24,7 @@ from logs.log import algo_log
|
||||
from SSVEP.dwfbcca import FbccaDw
|
||||
# from Tools.plot_MI_EEG import plotMain
|
||||
from collections import deque
|
||||
|
||||
from Zmq.filterProcess import SlidingFilter
|
||||
|
||||
def get_root_path():
|
||||
"""
|
||||
@@ -63,6 +63,8 @@ class Decoder_main(threading.Thread):
|
||||
|
||||
self.zmqServer = zmqServer(device_info=self.device_info)
|
||||
self.zmqServer.start()
|
||||
|
||||
self.filter = SlidingFilter()
|
||||
|
||||
# self.zmqClient = zmqClient(_upper_host, _upper_port)
|
||||
# self.zmqClient.set_zmq_server(self.zmqServer)
|
||||
@@ -188,10 +190,10 @@ class Decoder_main(threading.Thread):
|
||||
self.load_model = False # 调用模型是否完成的标志
|
||||
self.b_notch, self.a_notch = signal.iirnotch(50 / (self.device_info['sample_rate']/2), 30) # 50Hz工频陷波,250是采样率,30是质量因子
|
||||
self.b_design = signal.firwin(65, [bandPass_low / (self.device_info['sample_rate']/2), bandPass_high / (self.device_info['sample_rate']/2)], pass_zero=False) # 设计8-30Hz带通滤波器
|
||||
fileName = 'Model_' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
|
||||
filePath = os.path.join(get_root_path(), MODEL_FOLDER) + os.sep
|
||||
for old_pth in glob.glob(os.path.join(filePath, '*.pth')):
|
||||
os.remove(old_pth)
|
||||
fileName = 'Model_' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
|
||||
self.modelPath = ''.join([filePath, fileName, '.pth'])
|
||||
self.mp_data_queue = mp.Queue()
|
||||
self.mp_result_queue = mp.Queue()
|
||||
@@ -220,54 +222,6 @@ class Decoder_main(threading.Thread):
|
||||
# self.zmqClient.send_to_all('sync', self.zmqClient.state)
|
||||
self.zmqServer.state_mode = 'rest'
|
||||
|
||||
# # 状态异常,报告上位机
|
||||
# if self.status_code != self.thread_data_server.status_code:
|
||||
# self.status_code = self.thread_data_server.status_code
|
||||
# self.zmqClient.send_to_all('status_code', int(self.status_code))
|
||||
# print('status code')
|
||||
|
||||
# # 返回电量
|
||||
# if self.energy != self.thread_data_server.energy:
|
||||
# self.energy = self.thread_data_server.energy
|
||||
# self.zmqClient.send_to_all('energy', int(self.energy))
|
||||
# print('energy')
|
||||
|
||||
# if self.zmqServer.open_Impedance == True: # 开启阻抗检测功能,仅运行一次
|
||||
# self.thread_data_server.Impedance(True)
|
||||
# print('Impedance')
|
||||
# self.zmqServer.open_Impedance = -1
|
||||
# elif self.zmqServer.open_Impedance == False:
|
||||
# self.thread_data_server.Impedance(False)
|
||||
# self.zmqServer.open_Impedance = -1
|
||||
|
||||
# if self.zmqServer.get_Impedance: # 返回阻抗值
|
||||
# # print(self.zmqServer.get_Impedance)
|
||||
# # print(self.thread_data_server.GetDataLenCount())
|
||||
# if self.thread_data_server.GetDataLenCount() > 250:
|
||||
# Impe_data = self.thread_data_server.getData(250)
|
||||
# # 计算阻抗
|
||||
# imps = self.thread_data_server.getImpedance(Impe_data,self.zmqServer.decoder_class)
|
||||
# self.zmqClient.send_to_all('impedance', imps.tolist())
|
||||
# else:
|
||||
# pass
|
||||
# if self.zmqServer.getReport: #返回训练报告内容
|
||||
# self.zmqServer.getReport = False
|
||||
# allData = np.array(self.plotData)
|
||||
# allLabel = np.array(self.plotLabel) + 1
|
||||
# nTrials = min(len(allLabel),len(allData))
|
||||
# if nTrials < 30:
|
||||
# self.zmqClient.send_to_all('miReport',0)
|
||||
# else:
|
||||
# allData = allData[:nTrials]
|
||||
# allLabel = allLabel[:nTrials]
|
||||
# ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1',
|
||||
# 'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4']
|
||||
# compare_names = ['C3', 'CZ', 'C4']
|
||||
# miReport = plotMain(ch_names=ch_names,compare_names=compare_names,Data=allData,labels=allLabel,MI_label=1,Rest_label=2,
|
||||
# fs=self.device_info['sample_rate'])
|
||||
# self.zmqClient.send_to_all('miReport',miReport)
|
||||
|
||||
|
||||
# --- 取数优先:先执行 decoder(消费环形缓冲),再处理 plot/report 等重负载 ---
|
||||
try:
|
||||
if self.decoder_class == 'ssvep' or self.decoder_class == 'pvs':
|
||||
@@ -276,17 +230,12 @@ class Decoder_main(threading.Thread):
|
||||
self.decoder_SSMVEP()
|
||||
elif self.decoder_class == 'mi':
|
||||
self.decoder_MI()
|
||||
# elif self.decoder_class == 'concentration':
|
||||
# self.decoder_concentration()
|
||||
# elif self.decoder_class == 'blink':
|
||||
# self.decoder_blink()
|
||||
# else:
|
||||
# self.
|
||||
# # if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||
# # if self.thread_data_server.GetDataLenCount() < 25:
|
||||
# # time.sleep(0.005)
|
||||
# # continue;
|
||||
# # self.thread_data_server.getData(25)
|
||||
else:
|
||||
if self.zmqServer.open_Impedance == False: # 非阻抗检测状态
|
||||
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
|
||||
time.sleep(0.005)
|
||||
continue;
|
||||
self.zmqServer.paradigmBuffer.getData(25)
|
||||
except Exception as e:
|
||||
algo_log(f"Decoder Loop Error: {e}")
|
||||
time.sleep(0.1) # Prevent CPU spin if error is persistent
|
||||
@@ -344,18 +293,18 @@ class Decoder_main(threading.Thread):
|
||||
if self.zmqServer.StartTrain:
|
||||
self.currentLabel = self.zmqServer.currentLabel
|
||||
self.zmqServer.StartTrain = False
|
||||
if self.thread_data_server.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
||||
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
||||
self.train_epoch[1] \
|
||||
+ self.thread_data_server.event_inner_idx:
|
||||
+ self.zmqServer.event_inner_idx:
|
||||
time.sleep(0.0001)
|
||||
return
|
||||
print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount())
|
||||
trainTrial = self.thread_data_server.get_SSMVEPData() # 取出所有数据
|
||||
print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, self.thread_data_server.event_inner_idx])
|
||||
trainTrial = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 取出所有数据
|
||||
print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, self.zmqServer.event_inner_idx])
|
||||
trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理
|
||||
trainTrial = trainTrial[:, self.thread_data_server.event_inner_idx + self.train_epoch[
|
||||
0]:self.thread_data_server.event_inner_idx + self.train_epoch[1]]
|
||||
print('trial: ', self.thread_data_server.event_inner_idx, self.train_epoch[0], self.train_epoch[1])
|
||||
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.train_epoch[
|
||||
0]:self.zmqServer.event_inner_idx + self.train_epoch[1]]
|
||||
print('trial: ', self.zmqServer.event_inner_idx, self.train_epoch[0], self.train_epoch[1])
|
||||
if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance(
|
||||
self.trainLabel, list) \
|
||||
and self.trainLabel.count(self.currentLabel) < self.single_train:
|
||||
@@ -373,17 +322,17 @@ class Decoder_main(threading.Thread):
|
||||
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
||||
print('启动预测 ', formatted_time)
|
||||
|
||||
if self.thread_data_server.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
||||
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
||||
self.interval_epoch[1] \
|
||||
+ self.thread_data_server.event_inner_idx:
|
||||
+ self.zmqServer.event_inner_idx:
|
||||
time.sleep(0.0001)
|
||||
return
|
||||
data = self.thread_data_server.get_SSMVEPData() # 读取全部数据
|
||||
print('取出的: ', data.shape, 'event: ', data[-2, self.thread_data_server.event_inner_idx])
|
||||
data = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 读取全部数据
|
||||
print('取出的: ', data.shape, 'event: ', data[-2, self.zmqServer.event_inner_idx])
|
||||
data = self.preprocess(data[:self.n_chan, :]) # 预处理
|
||||
data = data[:,
|
||||
self.thread_data_server.event_inner_idx + self.interval_epoch[
|
||||
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]]
|
||||
self.zmqServer.event_inner_idx + self.interval_epoch[
|
||||
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]
|
||||
pad_eeg_test = np.zeros(
|
||||
(data.shape[0], int((self.sample_length + 0.1) * self.device_info['sample_rate'])))
|
||||
pad_eeg_test[:, :int(self.sample_length * self.device_info['sample_rate'])] = data
|
||||
@@ -392,20 +341,20 @@ class Decoder_main(threading.Thread):
|
||||
choosenNum = choosenNum[0]
|
||||
print('结果:', choosenNum, 'rho: ', sorted(features_2[0]),
|
||||
sorted(features_2[0])[-1] - sorted(features_2[0])[-2])
|
||||
self.zmqClient.send_to_all('result', int(choosenNum))
|
||||
self.zmqServer.broadcast_message('result', int(choosenNum))
|
||||
print('发送给界面完成。')
|
||||
else: # 休息状态
|
||||
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||
if self.zmqServer.open_Impedance == False: # 非阻抗检测状态
|
||||
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
|
||||
time.sleep(0.005)
|
||||
return
|
||||
self.thread_data_server.getData(25)
|
||||
self.zmqServer.paradigmBuffer.getData(25)
|
||||
|
||||
def decoder_MI(self):
|
||||
'''模型训练'''
|
||||
if self.train_started == False and all(
|
||||
self.trainLabel.count(i) >= self.single_train for i in range(self.num_target)): # 模型尚未训练
|
||||
self.zmqClient.send_to_all('paradigm', 2) # 模型训练前,训练集采集完毕,通知上位机
|
||||
self.zmqServer.broadcast_message('paradigm', 2) # 模型训练前,训练集采集完毕,通知上位机
|
||||
self.train_started = True
|
||||
self.trainData = np.array(self.trainData)
|
||||
self.trainLabel = np.array(self.trainLabel) + 1
|
||||
@@ -431,7 +380,7 @@ class Decoder_main(threading.Thread):
|
||||
with torch.no_grad():
|
||||
_ = self.model(warmup_data)
|
||||
self.load_model = True
|
||||
self.zmqClient.send_to_all('paradigm', 1) # 模型调用完毕,通知上位机
|
||||
self.zmqServer.broadcast_message('paradigm', 1) # 模型调用完毕,通知上位机
|
||||
else:
|
||||
print("训练失败:", result['msg'])
|
||||
except Empty:
|
||||
@@ -444,26 +393,26 @@ class Decoder_main(threading.Thread):
|
||||
if self.zmqServer.StartTrain:
|
||||
self.currentLabel = self.zmqServer.currentLabel
|
||||
self.zmqServer.StartTrain = False
|
||||
if self.thread_data_server.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
||||
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
||||
self.interval_epoch[1] \
|
||||
+ self.thread_data_server.event_inner_idx:
|
||||
+ self.zmqServer.event_inner_idx:
|
||||
time.sleep(0.0001)
|
||||
return
|
||||
print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount())
|
||||
originalTrial = self.thread_data_server.get_MIData() # 取出MI导联数据
|
||||
print('取出的: ', originalTrial.shape, 'event: ', originalTrial[-2, self.thread_data_server.event_inner_idx])
|
||||
originalTrial = self.zmqServer.paradigmBuffer.get_MIData() # 取出MI导联数据
|
||||
print('取出的: ', originalTrial.shape, 'event: ', originalTrial[-2, self.zmqServer.event_inner_idx])
|
||||
trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理
|
||||
trainTrial = trainTrial[:, self.thread_data_server.event_inner_idx + self.interval_epoch[
|
||||
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]]
|
||||
print('trial: ', self.thread_data_server.event_inner_idx, self.interval_epoch[0], self.interval_epoch[1])
|
||||
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.interval_epoch[
|
||||
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]
|
||||
print('trial: ', self.zmqServer.event_inner_idx, self.interval_epoch[0], self.interval_epoch[1])
|
||||
if trainTrial.shape[1] == (self.interval_epoch[1] - self.interval_epoch[0]) and isinstance(self.trainLabel,
|
||||
list) \
|
||||
and self.trainLabel.count(self.currentLabel) < self.single_train:
|
||||
self.trainData.append(trainTrial)
|
||||
self.trainLabel.append(self.currentLabel)
|
||||
print('训练集:', np.shape(self.trainData))
|
||||
self.plotData.append(originalTrial[:self.n_chan, self.thread_data_server.event_inner_idx + self.interval_epoch[
|
||||
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]])
|
||||
self.plotData.append(originalTrial[:self.n_chan, self.zmqServer.event_inner_idx + self.interval_epoch[
|
||||
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]])
|
||||
self.plotLabel.append(self.currentLabel)
|
||||
|
||||
elif self.zmqServer.state_mode == 'predict' and self.load_model == True: # 测试状态
|
||||
@@ -473,21 +422,21 @@ class Decoder_main(threading.Thread):
|
||||
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
||||
print('启动预测 ', formatted_time)
|
||||
|
||||
if self.thread_data_server.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
||||
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
||||
self.interval_epoch[1] \
|
||||
+ self.thread_data_server.event_inner_idx:
|
||||
+ self.zmqServer.event_inner_idx:
|
||||
time.sleep(0.0001)
|
||||
return
|
||||
originalData = self.thread_data_server.get_MIData() # 读取全部数据
|
||||
print('取出的: ', originalData.shape, 'event: ', originalData[-2, self.thread_data_server.event_inner_idx])
|
||||
originalData = self.zmqServer.paradigmBuffer.get_MIData() # 读取全部数据
|
||||
print('取出的: ', originalData.shape, 'event: ', originalData[-2, self.zmqServer.event_inner_idx])
|
||||
start = time.time()
|
||||
data = self.preprocess(originalData[:self.n_chan, :]) # 预处理
|
||||
data = data[:,
|
||||
self.thread_data_server.event_inner_idx + self.interval_epoch[
|
||||
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]]
|
||||
self.zmqServer.event_inner_idx + self.interval_epoch[
|
||||
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]
|
||||
self.plotData.append(
|
||||
originalData[:self.n_chan, self.thread_data_server.event_inner_idx + self.interval_epoch[
|
||||
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]])
|
||||
originalData[:self.n_chan, self.zmqServer.event_inner_idx + self.interval_epoch[
|
||||
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]])
|
||||
|
||||
test_data = data[np.newaxis, np.newaxis, :, :]
|
||||
test_data = torch.from_numpy(test_data)
|
||||
@@ -497,15 +446,15 @@ class Decoder_main(threading.Thread):
|
||||
y_pred = torch.max(Cls, 1)[1]
|
||||
self.plotLabel.append(int(y_pred.item()))
|
||||
print('运动意图识别: ', y_pred)
|
||||
self.zmqClient.send_to_all('result', int(y_pred.item()))
|
||||
self.zmqServer.broadcast_message('paradigm', int(y_pred.item()))
|
||||
end = time.time()
|
||||
print(f'发送给界面完成,耗时{end - start:.3f}s。')
|
||||
else: # 休息状态
|
||||
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||
if self.zmqServer.open_Impedance == False: # 非阻抗检测状态
|
||||
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
|
||||
time.sleep(0.005)
|
||||
return
|
||||
self.thread_data_server.getData(25)
|
||||
self.zmqServer.paradigmBuffer.getData(25)
|
||||
|
||||
# def decoder_concentration(self):
|
||||
# if self.zmqServer.state_mode == 'predict':
|
||||
@@ -531,99 +480,6 @@ class Decoder_main(threading.Thread):
|
||||
# return
|
||||
# self.thread_data_server.getData(25)
|
||||
|
||||
#### Blink detection #####
|
||||
# def check_double_blink(self, current_time):
|
||||
# """
|
||||
# 检查是否检测到连续两次眨眼
|
||||
# @param current_time: 当前眨眼时间戳
|
||||
# @return: True表示检测到连续两次眨眼
|
||||
# """
|
||||
# if len(self.blink_timestamps) < 2:
|
||||
# return False
|
||||
|
||||
# # 检查是否在去抖期内
|
||||
# if self.last_double_blink_time > 0:
|
||||
# time_since_last_double_blink = current_time - self.last_double_blink_time
|
||||
# if time_since_last_double_blink < self.double_blink_jitter:
|
||||
# return False # 在去抖期内,忽略连续眨眼检测
|
||||
# last_time = self.blink_timestamps[-1] # 当前眨眼
|
||||
# prev_time = self.blink_timestamps[-2] # 上次眨眼
|
||||
|
||||
# interval = last_time - prev_time
|
||||
# if interval <= self.double_blink_interval:
|
||||
# return True
|
||||
|
||||
# return False
|
||||
|
||||
# def process_blink_detection(self):
|
||||
# """
|
||||
# 在缓冲区数据上执行,单次眨眼检测
|
||||
# """
|
||||
# if len(self.fp1_buffer) < self.window_samples:
|
||||
# return
|
||||
|
||||
# fp1_data = np.array(list(self.fp1_buffer)[-self.window_samples:])
|
||||
# fp2_data = np.array(list(self.fp2_buffer)[-self.window_samples:])
|
||||
# # 计算FP1和FP2的平均
|
||||
# fp12_mean = (fp1_data + fp2_data) / 2.0
|
||||
# # 带通滤波
|
||||
# try:
|
||||
# fp12_filtered = signal.filtfilt(self.blink_b, self.blink_a, fp12_mean)
|
||||
# except Exception as e:
|
||||
# print(f"Filter error: {e}")
|
||||
# return
|
||||
# F = np.diff(fp12_filtered)
|
||||
# if len(F) < 3:
|
||||
# return
|
||||
# b, d, e = blink_detection(F, self.device_info['sample_rate'], self.Dmin, self.Dmax, self.EMin, self.EMax)
|
||||
|
||||
# if b == 1:
|
||||
# samples_since_last = self.total_samples - self.last_blink_time
|
||||
# time_since_last_ms = (samples_since_last / self.device_info['sample_rate']) * 1000
|
||||
# if time_since_last_ms >= self.jitterwin: # self.jitterwin 单次眨眼去抖 using time_since_last_ms
|
||||
# self.blink_count += 1
|
||||
# self.last_blink_time = self.total_samples
|
||||
# current_time = time.time()
|
||||
# self.blink_timestamps.append(current_time)
|
||||
# blink_event = {
|
||||
# 'count': self.blink_count,
|
||||
# 'time': current_time,
|
||||
# 'sample_index': self.total_samples,
|
||||
# 'duration_ms': d,
|
||||
# 'energy': e
|
||||
# }
|
||||
# self.blink_events.append(blink_event)
|
||||
# self.zmqClient.send_to_all('result', 1) # 检测到眨眼信号,通知上位机
|
||||
# if self.check_double_blink(current_time):
|
||||
# self.double_blink_count += 1
|
||||
# interval = self.blink_timestamps[-1] - self.blink_timestamps[-2]
|
||||
# double_blink_event = {
|
||||
# 'double_blink_count': self.double_blink_count,
|
||||
# 'blink1_time': self.blink_timestamps[-2],
|
||||
# 'blink2_time': self.blink_timestamps[-1],
|
||||
# 'interval': interval
|
||||
# }
|
||||
# self.double_blink_events.append(double_blink_event)
|
||||
# self.last_double_blink_time = current_time
|
||||
# self.zmqClient.send_to_all('result', 2) # 发送双次眨眼事件
|
||||
|
||||
# def decoder_blink(self):
|
||||
# if self.thread_data_server.GetDataLenCount() < 50:
|
||||
# time.sleep(0.005)
|
||||
# return
|
||||
# if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||
# data = self.thread_data_server.get_blinkData(50)
|
||||
# fp1_data = data[0, :] # ch1 (相当于FP1)
|
||||
# fp2_data = data[1, :] # ch2 (相当于FP2)
|
||||
# for i in range(len(fp1_data)):
|
||||
# self.fp1_buffer.append(fp1_data[i])
|
||||
# self.fp2_buffer.append(fp2_data[i])
|
||||
# self.total_samples += 1
|
||||
# self.sample_counter += 1
|
||||
|
||||
# if self.sample_counter >= self.step_samples:
|
||||
# self.process_blink_detection()
|
||||
# self.sample_counter = 0
|
||||
|
||||
def stop(self):
|
||||
'''
|
||||
|
||||
Reference in New Issue
Block a user