This commit is contained in:
2026-06-09 14:23:25 +08:00
parent 07560304ca
commit 9f034d1105
5 changed files with 82 additions and 70 deletions

View File

@@ -26,6 +26,8 @@ from SSVEP.dwfbcca import FbccaDw
from collections import deque from collections import deque
from Zmq.filterProcess import SlidingFilter from Zmq.filterProcess import SlidingFilter
save_train_data = int(IniRead('system', 'save_train_data', 0))
def get_root_path(): def get_root_path():
""" """
Nuitka 打包专用:获取程序根目录(.py 或 .exe 所在目录) Nuitka 打包专用:获取程序根目录(.py 或 .exe 所在目录)
@@ -209,7 +211,6 @@ class Decoder_main(threading.Thread):
elif self.decoder_class == 'mi': elif self.decoder_class == 'mi':
self.decoder_MI() self.decoder_MI()
else: else:
if self.zmqServer.open_Impedance == False: # 非阻抗检测状态
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25: if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
time.sleep(0.005) time.sleep(0.005)
continue; continue;
@@ -223,31 +224,32 @@ class Decoder_main(threading.Thread):
self.zmqServer.StartDecode = False self.zmqServer.StartDecode = False
self.decodingSteps = 1 self.decodingSteps = 1
self.zmqServer.paradigmBuffer.resetAllPara() self.zmqServer.paradigmBuffer.resetAllPara()
print('启动预测') algo_log('启动SSVEP预测', level="DEBUG")
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 50: if self.zmqServer.paradigmBuffer.GetDataLenCount() < 50:
time.sleep(0.005) time.sleep(0.005)
return return
if self.zmqServer.open_Impedance: # 阻抗检测状态不解码 if self.zmqServer.open_Impedance: # 阻抗检测状态不解码
return return
data = self.zmqServer.paradigmBuffer.getDataViaSSVEP(50) data = self.zmqServer.paradigmBuffer.getDataViaSSVEP(50)
algo_log(f"SSVEP取出的{data.shape}, data = {data[:20]}", level="DEBUG")
data = data[:self.n_chan, :] data = data[:self.n_chan, :]
if self.decodingSteps == 1 and hasattr(self,'dw'): # 开始预热 if self.decodingSteps == 1 and hasattr(self,'dw'): # 开始预热
self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时 self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时
self.dw.warmFilter(data) # 预热 self.dw.warmFilter(data) # 预热
self.decodingSteps = 2 self.decodingSteps = 2
print('预热数据完成。开始预测') algo_log('SSVEP预热数据完成。开始预测', level="DEBUG")
return return
if self.decodingSteps == 2 and hasattr(self,'dw'): # 解码中 if self.decodingSteps == 2 and hasattr(self,'dw'): # 解码中
choosenNum = self.dw.fbccaDWMW(data, self.referenceData, self.DW_cost_tv, self.calculateCount) choosenNum = self.dw.fbccaDWMW(data, self.referenceData, self.DW_cost_tv, self.calculateCount)
self.calculateCount += 1 self.calculateCount += 1
if choosenNum != -1 and self.is_valid_signal(data): if choosenNum != -1 and self.is_valid_signal(data):
self.decodingSteps = 3 self.decodingSteps = 3
print('预测结果:' + str(choosenNum) + ',计算次数:' + str(self.calculateCount)) algo_log('SSVEP预测结果:' + str(choosenNum) + ',计算次数:' + str(self.calculateCount), level="DEBUG")
self.calculateCount = 0 self.calculateCount = 0
if self.decodingSteps == 3: # 发送解码后的信息 if self.decodingSteps == 3: # 发送解码后的信息
self.zmqServer.broadcast_message('result', int(choosenNum)) self.zmqServer.broadcast_message('result', int(choosenNum))
self.decodingSteps = 0 self.decodingSteps = 0
print('发送给界面完成。') algo_log('SSVEP发送给界面完成。', level="DEBUG")
def decoder_SSMVEP(self): def decoder_SSMVEP(self):
'''模型训练''' '''模型训练'''
@@ -255,34 +257,28 @@ class Decoder_main(threading.Thread):
self.trainLabel.count(i) >= self.single_train for i in range(len(self.list_freqs))): # 模型尚未训练完成 self.trainLabel.count(i) >= self.single_train for i in range(len(self.list_freqs))): # 模型尚未训练完成
self.trainData = np.array(self.trainData) self.trainData = np.array(self.trainData)
self.trainLabel = np.array(self.trainLabel) self.trainLabel = np.array(self.trainLabel)
print(np.shape(self.trainData), (self.trainLabel)) algo_log(f"开始SSMVEP模型训练数据形状{np.shape(self.trainData)},标签形状:{self.trainLabel.shape}", level="DEBUG")
# 保存多个数组到文件 if save_train_data == 1:
# np.savez('20250520_yy.npz', array1=self.trainData, array2=self.trainLabel) now_str = datetime.now().strftime("%Y%m%d_%H%M%S")
# self.decoder = self.fbtdca.fit(self.trainData, self.trainLabel, Yf=self.Yf) save_path = f"{now_str}.npz"
np.savez(save_path, array1=self.trainData, array2=self.trainLabel)
self.decoder = self.tdca.fit(self.trainData, self.trainLabel, Yf=self.Yf) self.decoder = self.tdca.fit(self.trainData, self.trainLabel, Yf=self.Yf)
now = datetime.now() now = datetime.now()
formatted_time = now.strftime('%H:%M:%S.%f')[:-3] formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
print('模型训练完成', formatted_time) algo_log(f"SSMVEP模型训练完成时间{formatted_time}", level="DEBUG")
self.load_model = True self.load_model = True
self.zmqServer.broadcast_message('paradigm', 1) self.zmqServer.broadcast_message('paradigm', 1)
'''训练阶段采集数据''' '''训练阶段采集数据'''
if self.zmqServer.state_mode == 'train': # 训练状态 if self.zmqServer.state_mode == 'train': # 训练状态
if self.zmqServer.epoch_finished and self.zmqServer.paradigmBuffer.GetDataLenCount() >= \ if self.zmqServer.epoch_finished and self.zmqServer.paradigmBuffer.GetDataLenCount() >= \
self.train_epoch[1] + self.zmqServer.event_inner_idx: self.train_epoch[1] + self.zmqServer.event_inner_idx:
self.currentLabel = self.zmqServer.currentLabel self.currentLabel = self.zmqServer.currentLabel
print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount())
trainTrial = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 取出所有数据 trainTrial = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 取出所有数据
algo_log(f"取出的:{trainTrial.shape}event{trainTrial[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, self.zmqServer.event_inner_idx])
trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理 trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.train_epoch[ trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.train_epoch[
0]:self.zmqServer.event_inner_idx + self.train_epoch[1]] 0]:self.zmqServer.event_inner_idx + self.train_epoch[1]]
print('trial: ', self.zmqServer.event_inner_idx, self.train_epoch[0], self.train_epoch[1])
if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance( if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance(
self.trainLabel, list) \ self.trainLabel, list) \
and self.trainLabel.count(self.currentLabel) < self.single_train: and self.trainLabel.count(self.currentLabel) < self.single_train:
@@ -301,15 +297,14 @@ class Decoder_main(threading.Thread):
self.zmqServer.StartDecode = False self.zmqServer.StartDecode = False
now = datetime.now() now = datetime.now()
formatted_time = now.strftime('%H:%M:%S.%f')[:-3] formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
print('启动预测 ', formatted_time) algo_log(f"SSMVEP模型启动预测 {formatted_time}", level="DEBUG")
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \ if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
self.interval_epoch[1] \ self.interval_epoch[1] \
+ self.zmqServer.event_inner_idx: + self.zmqServer.event_inner_idx:
time.sleep(0.0001) time.sleep(0.0001)
return return
data = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 读取全部数据 data = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 读取全部数据
print('取出的: ', data.shape, 'event: ', data[-2, self.zmqServer.event_inner_idx]) algo_log(f"取出的:{data.shape}, event: {data[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
data = self.preprocess(data[:self.n_chan, :]) # 预处理 data = self.preprocess(data[:self.n_chan, :]) # 预处理
data = data[:, data = data[:,
self.zmqServer.event_inner_idx + self.interval_epoch[ self.zmqServer.event_inner_idx + self.interval_epoch[
@@ -320,12 +315,10 @@ class Decoder_main(threading.Thread):
choosenNum, features_2 = self.decoder.predict(pad_eeg_test) choosenNum, features_2 = self.decoder.predict(pad_eeg_test)
if isinstance(choosenNum, np.ndarray): if isinstance(choosenNum, np.ndarray):
choosenNum = choosenNum[0] choosenNum = choosenNum[0]
print('结果:', choosenNum, 'rho: ', sorted(features_2[0]), algo_log(f"结果:{choosenNum}, rho: {sorted(features_2[0])[-1] - sorted(features_2[0])[-2]}", level="DEBUG")
sorted(features_2[0])[-1] - sorted(features_2[0])[-2])
self.zmqServer.broadcast_message('result', int(choosenNum)) self.zmqServer.broadcast_message('result', int(choosenNum))
print('发送给界面完成。') algo_log("SSMVEP发送给界面完成。", level="DEBUG")
else: # 休息状态 else: # 休息状态
if self.zmqServer.open_Impedance == False: # 非阻抗检测状态
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25: if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
time.sleep(0.005) time.sleep(0.005)
return return
@@ -339,7 +332,11 @@ class Decoder_main(threading.Thread):
self.train_started = True self.train_started = True
self.trainData = np.array(self.trainData) self.trainData = np.array(self.trainData)
self.trainLabel = np.array(self.trainLabel) + 1 self.trainLabel = np.array(self.trainLabel) + 1
# print('训练集:',np.shape(self.trainData), (self.trainLabel)) algo_log(f"MI开始训练训练集{np.shape(self.trainData)}标签shape{np.shape(self.trainLabel)}", level="DEBUG")
if save_train_data == 1:
now_str = datetime.now().strftime("%Y%m%d_%H%M%S")
save_path = f"{now_str}.npz"
np.savez(save_path, array1=self.trainData, array2=self.trainLabel)
p = mp.Process(target=onlineTrain, args=(self.mp_data_queue, self.mp_result_queue)) # 开启子进程,训练模型 p = mp.Process(target=onlineTrain, args=(self.mp_data_queue, self.mp_result_queue)) # 开启子进程,训练模型
p.start() p.start()
self.mp_data_queue.put({'data': self.trainData, 'label': self.trainLabel, 'modelPath': self.modelPath, self.mp_data_queue.put({'data': self.trainData, 'label': self.trainLabel, 'modelPath': self.modelPath,
@@ -350,7 +347,7 @@ class Decoder_main(threading.Thread):
try: try:
result = self.mp_result_queue.get_nowait() result = self.mp_result_queue.get_nowait()
if result['status'] == 'success': if result['status'] == 'success':
print("模型训练完成,加载新模型") algo_log("MI模型训练完成,加载新模型", level="DEBUG")
# 调用模型 # 调用模型
self.model = torch.load(self.modelPath, weights_only=False) self.model = torch.load(self.modelPath, weights_only=False)
self.model.eval() self.model.eval()
@@ -363,45 +360,42 @@ class Decoder_main(threading.Thread):
self.load_model = True self.load_model = True
self.zmqServer.broadcast_message('paradigm', 1) # 模型调用完毕,通知上位机 self.zmqServer.broadcast_message('paradigm', 1) # 模型调用完毕,通知上位机
else: else:
print("训练失败:", result['msg']) algo_log("MI训练失败: " + result['msg'], level="DEBUG")
except Empty: except Empty:
pass # 还没完成 pass # 还没完成
except Exception as e: except Exception as e:
print('模型调用失败: ', e) algo_log("MI模型训练失败: " + str(e), level="DEBUG")
'''训练阶段采集数据''' '''训练阶段采集数据'''
if self.zmqServer.state_mode == 'train' and self.train_started == False: # 训练状态 if self.zmqServer.state_mode == 'train' and self.train_started == False: # 训练状态
if self.zmqServer.StartTrain: if self.zmqServer.epoch_finished and self.zmqServer.paradigmBuffer.GetDataLenCount() >= \
self.currentLabel = self.zmqServer.currentLabel self.interval_epoch[1] + self.zmqServer.event_inner_idx:
self.zmqServer.StartTrain = False algo_log(f"训练队列数据:{self.zmqServer.paradigmBuffer.GetDataLenCount()}", level="DEBUG")
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
self.interval_epoch[1] \
+ self.zmqServer.event_inner_idx:
time.sleep(0.0001)
return
print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount())
originalTrial = self.zmqServer.paradigmBuffer.get_MIData() # 取出MI导联数据 originalTrial = self.zmqServer.paradigmBuffer.get_MIData() # 取出MI导联数据
print('取出的: ', originalTrial.shape, 'event: ', originalTrial[-2, self.zmqServer.event_inner_idx]) algo_log(f"取出的:{originalTrial.shape},event: {originalTrial[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理 trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.interval_epoch[ trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.interval_epoch[
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]] 0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]
print('trial: ', self.zmqServer.event_inner_idx, self.interval_epoch[0], self.interval_epoch[1]) algo_log(f"trial: {self.zmqServer.event_inner_idx},{self.interval_epoch[0]},{self.interval_epoch[1]}", level="DEBUG")
if trainTrial.shape[1] == (self.interval_epoch[1] - self.interval_epoch[0]) and isinstance(self.trainLabel, if trainTrial.shape[1] == (self.interval_epoch[1] - self.interval_epoch[0]) and isinstance(self.trainLabel,
list) \ list) \
and self.trainLabel.count(self.currentLabel) < self.single_train: and self.trainLabel.count(self.currentLabel) < self.single_train:
self.trainData.append(trainTrial) self.trainData.append(trainTrial)
self.trainLabel.append(self.currentLabel) self.trainLabel.append(self.currentLabel)
print('训练集:', np.shape(self.trainData)) algo_log(f"训练集:{np.shape(self.trainData)}", level="DEBUG")
self.plotData.append(originalTrial[:self.n_chan, self.zmqServer.event_inner_idx + self.interval_epoch[ self.plotData.append(originalTrial[:self.n_chan, self.zmqServer.event_inner_idx + self.interval_epoch[
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]) 0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]])
self.plotLabel.append(self.currentLabel) self.plotLabel.append(self.currentLabel)
else:
time.sleep(0.0001)
return
elif self.zmqServer.state_mode == 'predict' and self.load_model == True: # 测试状态 elif self.zmqServer.state_mode == 'predict' and self.load_model == True: # 测试状态
if self.zmqServer.StartDecode: if self.zmqServer.StartDecode:
self.zmqServer.StartDecode = False self.zmqServer.StartDecode = False
now = datetime.now() now = datetime.now()
formatted_time = now.strftime('%H:%M:%S.%f')[:-3] formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
print('启动预测 ', formatted_time) algo_log(f"MI启动预测 {formatted_time}", level="DEBUG")
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \ if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
self.interval_epoch[1] \ self.interval_epoch[1] \
@@ -409,7 +403,7 @@ class Decoder_main(threading.Thread):
time.sleep(0.0001) time.sleep(0.0001)
return return
originalData = self.zmqServer.paradigmBuffer.get_MIData() # 读取全部数据 originalData = self.zmqServer.paradigmBuffer.get_MIData() # 读取全部数据
print('取出的: ', originalData.shape, 'event: ', originalData[-2, self.zmqServer.event_inner_idx]) algo_log(f"取出的:{originalData.shape},event: {originalData[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
start = time.time() start = time.time()
data = self.preprocess(originalData[:self.n_chan, :]) # 预处理 data = self.preprocess(originalData[:self.n_chan, :]) # 预处理
data = data[:, data = data[:,
@@ -426,12 +420,11 @@ class Decoder_main(threading.Thread):
Cls = self.model(test_data) Cls = self.model(test_data)
y_pred = torch.max(Cls, 1)[1] y_pred = torch.max(Cls, 1)[1]
self.plotLabel.append(int(y_pred.item())) self.plotLabel.append(int(y_pred.item()))
print('运动意图识别: ', y_pred) algo_log(f"MI运动意图识别: {y_pred}")
self.zmqServer.broadcast_message('paradigm', int(y_pred.item())) self.zmqServer.broadcast_message('paradigm', int(y_pred.item()))
end = time.time() end = time.time()
print(f'发送给界面完成,耗时{end - start:.3f}s。') print(f'发送给界面完成,耗时{end - start:.3f}s。')
else: # 休息状态 else: # 休息状态
if self.zmqServer.open_Impedance == False: # 非阻抗检测状态
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25: if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
time.sleep(0.005) time.sleep(0.005)
return return

View File

@@ -20,3 +20,7 @@ python runDecoder.py
python datamock.py python datamock.py
python ZeroMQClient_mock.py python ZeroMQClient_mock.py
python system_test.py python system_test.py
# 遗留问题
1. mvep是否要把list freq 开放到config

View File

@@ -250,6 +250,20 @@ class zmqServer(threading.Thread):
self.decoder_switch = True self.decoder_switch = True
elif method == "train": elif method == "train":
self.state_mode = 'train' self.state_mode = 'train'
resp = {
"method": "train_response",
"params": {
"code": 200,
"message": "ok"
}
}
try:
resp_bytes = json.dumps(resp, ensure_ascii=False).encode("utf-8")
self.cmd_socket.send_multipart([ident, b"", resp_bytes])
algo_log(f"train 命令已即时回复客户端 {ident}", level="DEBUG")
except Exception as e:
algo_log(f"train 命令回复失败: {e}", level="ERROR")
return
elif method == "predict": elif method == "predict":
self.state_mode = 'predict' self.state_mode = 'predict'
if params == 1: #开始解码 if params == 1: #开始解码
@@ -360,7 +374,7 @@ class zmqServer(threading.Thread):
# -------------------------- 主循环 -------------------------- # -------------------------- 主循环 --------------------------
def run(self): def run(self):
self.running = True self.running = True
algo_log(f"ZMQ服务器启动成功 - 命令端口: {self.cmd_port}, 数据端口: {self.data_port}", level="INFO") algo_log(f"ZMQ服务器启动成功 - host: {self.host}, 命令端口: {self.cmd_port}, 数据端口: {self.data_port}", level="INFO")
try: try:
while self.running: while self.running:

View File

@@ -18,6 +18,7 @@ Upper_Port = 8088
Serial_port = COM44 Serial_port = COM44
algo_log_level = DEBUG algo_log_level = DEBUG
console_output = 1 console_output = 1
save_train_data = 0
; 64 导设备配置 ; 64 导设备配置
[device_type_1] [device_type_1]

View File

@@ -12,7 +12,7 @@ EEG_FREQ = 10 # EEG 正弦波频率 Hz
EEG_AMP = 100.0 # EEG 幅值 100μV EEG_AMP = 100.0 # EEG 幅值 100μV
LABEL_INTERVAL = 5 # 标签间隔秒数 LABEL_INTERVAL = 5 # 标签间隔秒数
# SERVER_ADDR = 'tcp://127.0.0.1:8100' # SERVER_ADDR = 'tcp://127.0.0.1:8100'
SERVER_ADDR = 'tcp://127.0.0.1:8100' SERVER_ADDR = 'tcp://10.200.27.140:8100'
# 发送间隔: 每包 5 采样点 / 250Hz = 20ms # 发送间隔: 每包 5 采样点 / 250Hz = 20ms
PKT_INTERVAL = N_SAMPLES_PER_PKT / FS PKT_INTERVAL = N_SAMPLES_PER_PKT / FS