diff --git a/Decoder.py b/Decoder.py index 6b83fed..85b6930 100644 --- a/Decoder.py +++ b/Decoder.py @@ -9,6 +9,7 @@ import numpy as np import time import torch from queue import Empty +import queue from scipy import signal from torch.autograd import Variable # from Device.SunnyLinker import SunnyLinker64 @@ -276,32 +277,28 @@ class Decoder_main(threading.Thread): '''训练阶段采集数据''' if self.zmqServer.state_mode == 'train': # 训练状态 - if self.zmqServer.pack_contain_event: - with self.zmqServer.paradigmBufferLock: - self.zmqServer.paradigmBuffer.resetAllPara() - self.zmqServer.pack_contain_event = False - - if self.zmqServer.epoch_finished: - data_length = self.zmqServer.paradigmBuffer.GetDataLenCount() - if data_length >= self.train_epoch[1] + self.zmqServer.event_inner_idx: - self.currentLabel = self.zmqServer.currentLabel - trainTrial = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 取出所有数据 - algo_log(f"取出的:{trainTrial.shape},event:{trainTrial[-2, self.zmqServer.event_inner_idx]}", level="DEBUG") - trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理 - trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.train_epoch[ - 0]:self.zmqServer.event_inner_idx + self.train_epoch[1]] - if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance( - self.trainLabel, list) \ - and self.trainLabel.count(self.currentLabel) < self.single_train: - self.trainData.append(trainTrial) - self.trainLabel.append(self.currentLabel) - algo_log(f"SSMVEP训练集:{np.shape(self.trainData)}", level="DEBUG") - else: - algo_log(f"epoch_finished {self.zmqServer.epoch_finished}, 数据长度不足 {data_length}", level="DEBUG") - self.zmqServer.epoch_finished = False - else: - time.sleep(0.001) + try: + epoch_payload = self.zmqServer.epoch_queue.get_nowait() + except queue.Empty: + time.sleep(0.0001) return + trainTrial = epoch_payload['snapshot'] + event_inner_idx = epoch_payload['event_inner_idx'] + self.currentLabel = epoch_payload['currentLabel'] + data_length = epoch_payload['data_length'] + if trainTrial is None or data_length < self.train_epoch[1] + event_inner_idx: + algo_log(f"SSMVEP epoch数据长度不足: {data_length}, 跳过", level="WARNING") + return + algo_log(f"取出的:{trainTrial.shape},event:{trainTrial[-2, event_inner_idx]}", level="DEBUG") + trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理 + trainTrial = trainTrial[:, event_inner_idx + self.train_epoch[ + 0]:event_inner_idx + self.train_epoch[1]] + if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance( + self.trainLabel, list) \ + and self.trainLabel.count(self.currentLabel) < self.single_train: + self.trainData.append(trainTrial) + self.trainLabel.append(self.currentLabel) + algo_log(f"SSMVEP训练集:{np.shape(self.trainData)}", level="DEBUG") elif self.zmqServer.state_mode == 'predict': # 测试状态 if self.load_model == False: # 模型尚未训练完成 @@ -313,18 +310,22 @@ class Decoder_main(threading.Thread): now = datetime.now() formatted_time = now.strftime('%H:%M:%S.%f')[:-3] algo_log(f"SSMVEP模型启动预测 {formatted_time}", level="DEBUG") - if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \ - self.interval_epoch[1] \ - + self.zmqServer.event_inner_idx: - # algo_log(f"SSMVEP模型启动预测 {self.zmqServer.epoch_finished}", level="DEBUG") + try: + epoch_payload = self.zmqServer.epoch_queue.get_nowait() + except queue.Empty: time.sleep(0.0001) return - data = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 读取全部数据 - algo_log(f"取出的:{data.shape}, event: {data[-2, self.zmqServer.event_inner_idx]}", level="DEBUG") + data = epoch_payload['snapshot'] + event_inner_idx = epoch_payload['event_inner_idx'] + data_length = epoch_payload['data_length'] + if data is None or data_length < self.interval_epoch[1] + event_inner_idx: + algo_log(f"SSMVEP predict epoch数据长度不足: {data_length}, 跳过", level="WARNING") + return + algo_log(f"取出的:{data.shape}, event: {data[-2, event_inner_idx]}", level="DEBUG") data = self.preprocess(data[:self.n_chan, :]) # 预处理 data = data[:, - self.zmqServer.event_inner_idx + self.interval_epoch[ - 0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]] + event_inner_idx + self.interval_epoch[ + 0]:event_inner_idx + self.interval_epoch[1]] pad_eeg_test = np.zeros( (data.shape[0], int((self.sample_length + 0.1) * self.device_info['sample_rate']))) pad_eeg_test[:, :int(self.sample_length * self.device_info['sample_rate'])] = data @@ -384,28 +385,33 @@ class Decoder_main(threading.Thread): '''训练阶段采集数据''' if self.zmqServer.state_mode == 'train' and self.train_started == False: # 训练状态 - if self.zmqServer.epoch_finished and self.zmqServer.paradigmBuffer.GetDataLenCount() >= \ - self.zmqServer.train_epoch[1] + self.zmqServer.event_inner_idx: - self.currentLabel = self.zmqServer.currentLabel # 同步当前标签 - algo_log(f"训练队列数据:{self.zmqServer.paradigmBuffer.GetDataLenCount()}", level="DEBUG") - originalTrial = self.zmqServer.paradigmBuffer.get_MIData() # 取出MI导联数据 - algo_log(f"取出的:{originalTrial.shape},event: {originalTrial[-2, self.zmqServer.event_inner_idx]}", level="DEBUG") - trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理 - trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.interval_epoch[ - 0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]] - # algo_log(f"trial: {self.zmqServer.event_inner_idx},{self.interval_epoch[0]},{self.interval_epoch[1]}", level="DEBUG") - if trainTrial.shape[1] == (self.interval_epoch[1] - self.interval_epoch[0]) and isinstance(self.trainLabel, - list) \ - and self.trainLabel.count(self.currentLabel) < self.single_train: - self.trainData.append(trainTrial) - self.trainLabel.append(self.currentLabel) - algo_log(f"训练集:{np.shape(self.trainData)}", level="DEBUG") - self.plotData.append(originalTrial[:self.n_chan, self.zmqServer.event_inner_idx + self.interval_epoch[ - 0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]) - self.plotLabel.append(self.currentLabel) - else: - time.sleep(0.001) + try: + epoch_payload = self.zmqServer.epoch_queue.get_nowait() + except queue.Empty: + time.sleep(0.0001) return + originalTrial = epoch_payload['snapshot'] + event_inner_idx = epoch_payload['event_inner_idx'] + self.currentLabel = epoch_payload['currentLabel'] + data_length = epoch_payload['data_length'] + if originalTrial is None or data_length < self.zmqServer.train_epoch[1] + event_inner_idx: + algo_log(f"epoch数据长度不足: {data_length}, 跳过", level="WARNING") + return + algo_log(f"训练队列数据:{data_length}", level="DEBUG") + algo_log(f"取出的:{originalTrial.shape},event: {originalTrial[-2, event_inner_idx]}", level="DEBUG") + trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理 + trainTrial = trainTrial[:, event_inner_idx + self.interval_epoch[ + 0]:event_inner_idx + self.interval_epoch[1]] + # algo_log(f"trial: {event_inner_idx},{self.interval_epoch[0]},{self.interval_epoch[1]}", level="DEBUG") + if trainTrial.shape[1] == (self.interval_epoch[1] - self.interval_epoch[0]) and isinstance(self.trainLabel, + list) \ + and self.trainLabel.count(self.currentLabel) < self.single_train: + self.trainData.append(trainTrial) + self.trainLabel.append(self.currentLabel) + algo_log(f"训练集:{np.shape(self.trainData)}", level="DEBUG") + self.plotData.append(originalTrial[:self.n_chan, event_inner_idx + self.interval_epoch[ + 0]:event_inner_idx + self.interval_epoch[1]]) + self.plotLabel.append(self.currentLabel) elif self.zmqServer.state_mode == 'predict' and self.load_model == True: # 测试状态 if self.zmqServer.StartDecode: @@ -414,21 +420,26 @@ class Decoder_main(threading.Thread): formatted_time = now.strftime('%H:%M:%S.%f')[:-3] algo_log(f"MI启动预测 {formatted_time}", level="DEBUG") - if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \ - self.interval_epoch[1] \ - + self.zmqServer.event_inner_idx: + try: + epoch_payload = self.zmqServer.epoch_queue.get_nowait() + except queue.Empty: time.sleep(0.001) return - originalData = self.zmqServer.paradigmBuffer.get_MIData() # 读取全部数据 - algo_log(f"取出的:{originalData.shape},event: {originalData[-2, self.zmqServer.event_inner_idx]}", level="DEBUG") + originalData = epoch_payload['snapshot'] + event_inner_idx = epoch_payload['event_inner_idx'] + data_length = epoch_payload['data_length'] + if originalData is None or data_length < self.interval_epoch[1] + event_inner_idx: + algo_log(f"predict epoch数据长度不足: {data_length}, 跳过", level="WARNING") + return + algo_log(f"取出的:{originalData.shape},event: {originalData[-2, event_inner_idx]}", level="DEBUG") start = time.time() data = self.preprocess(originalData[:self.n_chan, :]) # 预处理 data = data[:, - self.zmqServer.event_inner_idx + self.interval_epoch[ - 0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]] + event_inner_idx + self.interval_epoch[ + 0]:event_inner_idx + self.interval_epoch[1]] self.plotData.append( - originalData[:self.n_chan, self.zmqServer.event_inner_idx + self.interval_epoch[ - 0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]) + originalData[:self.n_chan, event_inner_idx + self.interval_epoch[ + 0]:event_inner_idx + self.interval_epoch[1]]) test_data = data[np.newaxis, np.newaxis, :, :] test_data = torch.from_numpy(test_data) diff --git a/Zmq/zmqServer.py b/Zmq/zmqServer.py index 15cc266..c6bbef0 100644 --- a/Zmq/zmqServer.py +++ b/Zmq/zmqServer.py @@ -83,6 +83,10 @@ class zmqServer(threading.Thread): # 发送队列(双端口分离) self.cmd_send_queue = queue.Queue() # 8099端口命令结果队列 self.data_send_queue = queue.Queue() # 8100端口滤波数据队列 + + # epoch完成通知队列:生产者(zmqServer)写入, 消费者(Decoder)读取 + # 每个元素是一个dict,包含完整的epoch数据快照,避免裸标志位竞态 + self.epoch_queue = queue.Queue(maxsize=10) # 范式buffer与事件检测参数 self.predict_event = 99 @@ -105,6 +109,12 @@ class zmqServer(threading.Thread): self.pack_contain_event = False self.event_inner_idx = -1 self.interval_inited = False + # 清空epoch队列,防止残留旧epoch被新阶段消费 + while not self.epoch_queue.empty(): + try: + self.epoch_queue.get_nowait() + except queue.Empty: + break def interval_init(self, decoder_class): if decoder_class == 'ssmvep': @@ -132,6 +142,12 @@ class zmqServer(threading.Thread): self.predict_event = 99 self.events = [1, 2, self.predict_event] self.interval_inited = True + # 清空epoch队列,防止旧范式的残留epoch被新阶段消费 + while not self.epoch_queue.empty(): + try: + self.epoch_queue.get_nowait() + except queue.Empty: + break # -------------------------- 8099端口:命令结果广播 -------------------------- def broadcast_message(self, method, params): @@ -341,10 +357,60 @@ class zmqServer(threading.Thread): with self.paradigmBufferLock: self.paradigmBuffer.appendBuffer(data_np) if self.interval_inited: - self.pack_contain_event, self.epoch_finished = self.detect_event(data_np) - if self.epoch_finished: - algo_log(f"Epoch采集完成, 当前数据长度{self.paradigmBuffer.GetDataLenCount()}", level="DEBUG") + self.epoch_finished = self.detect_event(data_np) + if self.pack_contain_event: + self.paradigmBuffer.resetAllPara() + self.paradigmBuffer.appendBuffer(data_np) + if self.epoch_finished: + now = datetime.datetime.now() + time_diff_str = "" + # 计算与上一次Epoch完成的时间差 + if self.last_epoch_finish_time is not None: + delta_seconds = (now - self.last_epoch_finish_time).total_seconds() + time_diff_str = f" | 与上一次间隔: {delta_seconds:.3f} s" + log_msg = f"Epoch采集完成: {now.strftime('%H:%M:%S.%f')[:-3]}{time_diff_str}" + algo_log(log_msg, level="DEBUG") + self.last_epoch_finish_time = now + + # ---- 队列化通知:在锁内快照数据,原子放入epoch_queue ---- + # get_MIData / get_SSMVEPData 内部调用 getData(),返回连续数组副本 + # 快照完成后再由 Decoder 线程消费,彻底消除 TOCTOU 竞态 + data_length = self.paradigmBuffer.GetDataLenCount() + try: + if self.decoder_class == 'mi': + snapshot = self.paradigmBuffer.get_MIData() + elif self.decoder_class == 'ssmvep': + snapshot = self.paradigmBuffer.get_SSMVEPData() + else: + snapshot = None + except Exception as snap_err: + algo_log(f"epoch快照失败: {snap_err}", level="ERROR") + snapshot = None + + epoch_payload = { + 'snapshot': snapshot, # 已选通道的数据副本 + 'currentLabel': self.currentLabel, + 'event_inner_idx': self.event_inner_idx, + 'data_length': data_length, + 'completion_time': now, + 'decoder_class': self.decoder_class, + } + try: + self.epoch_queue.put_nowait(epoch_payload) + except queue.Full: + # 队列满时丢弃最旧的epoch,保留最新 + try: + self.epoch_queue.get_nowait() + except queue.Empty: + pass + try: + self.epoch_queue.put_nowait(epoch_payload) + algo_log("epoch_queue已满,已丢弃最旧epoch", level="WARNING") + except queue.Full: + algo_log("epoch_queue放入失败", level="ERROR") + else: + self.paradigmBuffer.appendBuffer(data_np) except Exception as e: algo_log(f"数据处理失败: {str(e)}", level="ERROR")