打标epoch_finished, event_inner_index,epoch data 走队列

This commit is contained in:
Ivey Song
2026-06-14 19:32:48 +08:00
parent 5d3cd0dba9
commit bfd3fa27b3
2 changed files with 143 additions and 66 deletions

View File

@@ -9,6 +9,7 @@ import numpy as np
import time import time
import torch import torch
from queue import Empty from queue import Empty
import queue
from scipy import signal from scipy import signal
from torch.autograd import Variable from torch.autograd import Variable
# from Device.SunnyLinker import SunnyLinker64 # from Device.SunnyLinker import SunnyLinker64
@@ -276,32 +277,28 @@ class Decoder_main(threading.Thread):
'''训练阶段采集数据''' '''训练阶段采集数据'''
if self.zmqServer.state_mode == 'train': # 训练状态 if self.zmqServer.state_mode == 'train': # 训练状态
if self.zmqServer.pack_contain_event: try:
with self.zmqServer.paradigmBufferLock: epoch_payload = self.zmqServer.epoch_queue.get_nowait()
self.zmqServer.paradigmBuffer.resetAllPara() except queue.Empty:
self.zmqServer.pack_contain_event = False time.sleep(0.0001)
return
if self.zmqServer.epoch_finished: trainTrial = epoch_payload['snapshot']
data_length = self.zmqServer.paradigmBuffer.GetDataLenCount() event_inner_idx = epoch_payload['event_inner_idx']
if data_length >= self.train_epoch[1] + self.zmqServer.event_inner_idx: self.currentLabel = epoch_payload['currentLabel']
self.currentLabel = self.zmqServer.currentLabel data_length = epoch_payload['data_length']
trainTrial = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 取出所有数据 if trainTrial is None or data_length < self.train_epoch[1] + event_inner_idx:
algo_log(f"取出的:{trainTrial.shape}event{trainTrial[-2, self.zmqServer.event_inner_idx]}", level="DEBUG") algo_log(f"SSMVEP epoch数据长度不足: {data_length}, 跳过", level="WARNING")
return
algo_log(f"取出的:{trainTrial.shape}event{trainTrial[-2, event_inner_idx]}", level="DEBUG")
trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理 trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.train_epoch[ trainTrial = trainTrial[:, event_inner_idx + self.train_epoch[
0]:self.zmqServer.event_inner_idx + self.train_epoch[1]] 0]:event_inner_idx + self.train_epoch[1]]
if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance( if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance(
self.trainLabel, list) \ self.trainLabel, list) \
and self.trainLabel.count(self.currentLabel) < self.single_train: and self.trainLabel.count(self.currentLabel) < self.single_train:
self.trainData.append(trainTrial) self.trainData.append(trainTrial)
self.trainLabel.append(self.currentLabel) self.trainLabel.append(self.currentLabel)
algo_log(f"SSMVEP训练集{np.shape(self.trainData)}", level="DEBUG") algo_log(f"SSMVEP训练集{np.shape(self.trainData)}", level="DEBUG")
else:
algo_log(f"epoch_finished {self.zmqServer.epoch_finished}, 数据长度不足 {data_length}", level="DEBUG")
self.zmqServer.epoch_finished = False
else:
time.sleep(0.001)
return
elif self.zmqServer.state_mode == 'predict': # 测试状态 elif self.zmqServer.state_mode == 'predict': # 测试状态
if self.load_model == False: # 模型尚未训练完成 if self.load_model == False: # 模型尚未训练完成
@@ -313,18 +310,22 @@ class Decoder_main(threading.Thread):
now = datetime.now() now = datetime.now()
formatted_time = now.strftime('%H:%M:%S.%f')[:-3] formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
algo_log(f"SSMVEP模型启动预测 {formatted_time}", level="DEBUG") algo_log(f"SSMVEP模型启动预测 {formatted_time}", level="DEBUG")
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \ try:
self.interval_epoch[1] \ epoch_payload = self.zmqServer.epoch_queue.get_nowait()
+ self.zmqServer.event_inner_idx: except queue.Empty:
# algo_log(f"SSMVEP模型启动预测 {self.zmqServer.epoch_finished}", level="DEBUG")
time.sleep(0.0001) time.sleep(0.0001)
return return
data = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 读取全部数据 data = epoch_payload['snapshot']
algo_log(f"取出的:{data.shape}, event: {data[-2, self.zmqServer.event_inner_idx]}", level="DEBUG") event_inner_idx = epoch_payload['event_inner_idx']
data_length = epoch_payload['data_length']
if data is None or data_length < self.interval_epoch[1] + event_inner_idx:
algo_log(f"SSMVEP predict epoch数据长度不足: {data_length}, 跳过", level="WARNING")
return
algo_log(f"取出的:{data.shape}, event: {data[-2, event_inner_idx]}", level="DEBUG")
data = self.preprocess(data[:self.n_chan, :]) # 预处理 data = self.preprocess(data[:self.n_chan, :]) # 预处理
data = data[:, data = data[:,
self.zmqServer.event_inner_idx + self.interval_epoch[ event_inner_idx + self.interval_epoch[
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]] 0]:event_inner_idx + self.interval_epoch[1]]
pad_eeg_test = np.zeros( pad_eeg_test = np.zeros(
(data.shape[0], int((self.sample_length + 0.1) * self.device_info['sample_rate']))) (data.shape[0], int((self.sample_length + 0.1) * self.device_info['sample_rate'])))
pad_eeg_test[:, :int(self.sample_length * self.device_info['sample_rate'])] = data pad_eeg_test[:, :int(self.sample_length * self.device_info['sample_rate'])] = data
@@ -384,28 +385,33 @@ class Decoder_main(threading.Thread):
'''训练阶段采集数据''' '''训练阶段采集数据'''
if self.zmqServer.state_mode == 'train' and self.train_started == False: # 训练状态 if self.zmqServer.state_mode == 'train' and self.train_started == False: # 训练状态
if self.zmqServer.epoch_finished and self.zmqServer.paradigmBuffer.GetDataLenCount() >= \ try:
self.zmqServer.train_epoch[1] + self.zmqServer.event_inner_idx: epoch_payload = self.zmqServer.epoch_queue.get_nowait()
self.currentLabel = self.zmqServer.currentLabel # 同步当前标签 except queue.Empty:
algo_log(f"训练队列数据:{self.zmqServer.paradigmBuffer.GetDataLenCount()}", level="DEBUG") time.sleep(0.0001)
originalTrial = self.zmqServer.paradigmBuffer.get_MIData() # 取出MI导联数据 return
algo_log(f"取出的:{originalTrial.shape},event: {originalTrial[-2, self.zmqServer.event_inner_idx]}", level="DEBUG") originalTrial = epoch_payload['snapshot']
event_inner_idx = epoch_payload['event_inner_idx']
self.currentLabel = epoch_payload['currentLabel']
data_length = epoch_payload['data_length']
if originalTrial is None or data_length < self.zmqServer.train_epoch[1] + event_inner_idx:
algo_log(f"epoch数据长度不足: {data_length}, 跳过", level="WARNING")
return
algo_log(f"训练队列数据:{data_length}", level="DEBUG")
algo_log(f"取出的:{originalTrial.shape},event: {originalTrial[-2, event_inner_idx]}", level="DEBUG")
trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理 trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.interval_epoch[ trainTrial = trainTrial[:, event_inner_idx + self.interval_epoch[
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]] 0]:event_inner_idx + self.interval_epoch[1]]
# algo_log(f"trial: {self.zmqServer.event_inner_idx},{self.interval_epoch[0]},{self.interval_epoch[1]}", level="DEBUG") # algo_log(f"trial: {event_inner_idx},{self.interval_epoch[0]},{self.interval_epoch[1]}", level="DEBUG")
if trainTrial.shape[1] == (self.interval_epoch[1] - self.interval_epoch[0]) and isinstance(self.trainLabel, if trainTrial.shape[1] == (self.interval_epoch[1] - self.interval_epoch[0]) and isinstance(self.trainLabel,
list) \ list) \
and self.trainLabel.count(self.currentLabel) < self.single_train: and self.trainLabel.count(self.currentLabel) < self.single_train:
self.trainData.append(trainTrial) self.trainData.append(trainTrial)
self.trainLabel.append(self.currentLabel) self.trainLabel.append(self.currentLabel)
algo_log(f"训练集:{np.shape(self.trainData)}", level="DEBUG") algo_log(f"训练集:{np.shape(self.trainData)}", level="DEBUG")
self.plotData.append(originalTrial[:self.n_chan, self.zmqServer.event_inner_idx + self.interval_epoch[ self.plotData.append(originalTrial[:self.n_chan, event_inner_idx + self.interval_epoch[
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]) 0]:event_inner_idx + self.interval_epoch[1]])
self.plotLabel.append(self.currentLabel) self.plotLabel.append(self.currentLabel)
else:
time.sleep(0.001)
return
elif self.zmqServer.state_mode == 'predict' and self.load_model == True: # 测试状态 elif self.zmqServer.state_mode == 'predict' and self.load_model == True: # 测试状态
if self.zmqServer.StartDecode: if self.zmqServer.StartDecode:
@@ -414,21 +420,26 @@ class Decoder_main(threading.Thread):
formatted_time = now.strftime('%H:%M:%S.%f')[:-3] formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
algo_log(f"MI启动预测 {formatted_time}", level="DEBUG") algo_log(f"MI启动预测 {formatted_time}", level="DEBUG")
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \ try:
self.interval_epoch[1] \ epoch_payload = self.zmqServer.epoch_queue.get_nowait()
+ self.zmqServer.event_inner_idx: except queue.Empty:
time.sleep(0.001) time.sleep(0.001)
return return
originalData = self.zmqServer.paradigmBuffer.get_MIData() # 读取全部数据 originalData = epoch_payload['snapshot']
algo_log(f"取出的:{originalData.shape},event: {originalData[-2, self.zmqServer.event_inner_idx]}", level="DEBUG") event_inner_idx = epoch_payload['event_inner_idx']
data_length = epoch_payload['data_length']
if originalData is None or data_length < self.interval_epoch[1] + event_inner_idx:
algo_log(f"predict epoch数据长度不足: {data_length}, 跳过", level="WARNING")
return
algo_log(f"取出的:{originalData.shape},event: {originalData[-2, event_inner_idx]}", level="DEBUG")
start = time.time() start = time.time()
data = self.preprocess(originalData[:self.n_chan, :]) # 预处理 data = self.preprocess(originalData[:self.n_chan, :]) # 预处理
data = data[:, data = data[:,
self.zmqServer.event_inner_idx + self.interval_epoch[ event_inner_idx + self.interval_epoch[
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]] 0]:event_inner_idx + self.interval_epoch[1]]
self.plotData.append( self.plotData.append(
originalData[:self.n_chan, self.zmqServer.event_inner_idx + self.interval_epoch[ originalData[:self.n_chan, event_inner_idx + self.interval_epoch[
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]) 0]:event_inner_idx + self.interval_epoch[1]])
test_data = data[np.newaxis, np.newaxis, :, :] test_data = data[np.newaxis, np.newaxis, :, :]
test_data = torch.from_numpy(test_data) test_data = torch.from_numpy(test_data)

View File

@@ -84,6 +84,10 @@ class zmqServer(threading.Thread):
self.cmd_send_queue = queue.Queue() # 8099端口命令结果队列 self.cmd_send_queue = queue.Queue() # 8099端口命令结果队列
self.data_send_queue = queue.Queue() # 8100端口滤波数据队列 self.data_send_queue = queue.Queue() # 8100端口滤波数据队列
# epoch完成通知队列生产者(zmqServer)写入, 消费者(Decoder)读取
# 每个元素是一个dict包含完整的epoch数据快照避免裸标志位竞态
self.epoch_queue = queue.Queue(maxsize=10)
# 范式buffer与事件检测参数 # 范式buffer与事件检测参数
self.predict_event = 99 self.predict_event = 99
self.events = [1, 2, self.predict_event] self.events = [1, 2, self.predict_event]
@@ -105,6 +109,12 @@ class zmqServer(threading.Thread):
self.pack_contain_event = False self.pack_contain_event = False
self.event_inner_idx = -1 self.event_inner_idx = -1
self.interval_inited = False self.interval_inited = False
# 清空epoch队列防止残留旧epoch被新阶段消费
while not self.epoch_queue.empty():
try:
self.epoch_queue.get_nowait()
except queue.Empty:
break
def interval_init(self, decoder_class): def interval_init(self, decoder_class):
if decoder_class == 'ssmvep': if decoder_class == 'ssmvep':
@@ -132,6 +142,12 @@ class zmqServer(threading.Thread):
self.predict_event = 99 self.predict_event = 99
self.events = [1, 2, self.predict_event] self.events = [1, 2, self.predict_event]
self.interval_inited = True self.interval_inited = True
# 清空epoch队列防止旧范式的残留epoch被新阶段消费
while not self.epoch_queue.empty():
try:
self.epoch_queue.get_nowait()
except queue.Empty:
break
# -------------------------- 8099端口命令结果广播 -------------------------- # -------------------------- 8099端口命令结果广播 --------------------------
def broadcast_message(self, method, params): def broadcast_message(self, method, params):
@@ -341,10 +357,60 @@ class zmqServer(threading.Thread):
with self.paradigmBufferLock: with self.paradigmBufferLock:
self.paradigmBuffer.appendBuffer(data_np) self.paradigmBuffer.appendBuffer(data_np)
if self.interval_inited: if self.interval_inited:
self.pack_contain_event, self.epoch_finished = self.detect_event(data_np) self.epoch_finished = self.detect_event(data_np)
if self.epoch_finished: if self.pack_contain_event:
algo_log(f"Epoch采集完成, 当前数据长度{self.paradigmBuffer.GetDataLenCount()}", level="DEBUG") self.paradigmBuffer.resetAllPara()
self.paradigmBuffer.appendBuffer(data_np)
if self.epoch_finished:
now = datetime.datetime.now()
time_diff_str = ""
# 计算与上一次Epoch完成的时间差
if self.last_epoch_finish_time is not None:
delta_seconds = (now - self.last_epoch_finish_time).total_seconds()
time_diff_str = f" | 与上一次间隔: {delta_seconds:.3f} s"
log_msg = f"Epoch采集完成: {now.strftime('%H:%M:%S.%f')[:-3]}{time_diff_str}"
algo_log(log_msg, level="DEBUG")
self.last_epoch_finish_time = now
# ---- 队列化通知在锁内快照数据原子放入epoch_queue ----
# get_MIData / get_SSMVEPData 内部调用 getData(),返回连续数组副本
# 快照完成后再由 Decoder 线程消费,彻底消除 TOCTOU 竞态
data_length = self.paradigmBuffer.GetDataLenCount()
try:
if self.decoder_class == 'mi':
snapshot = self.paradigmBuffer.get_MIData()
elif self.decoder_class == 'ssmvep':
snapshot = self.paradigmBuffer.get_SSMVEPData()
else:
snapshot = None
except Exception as snap_err:
algo_log(f"epoch快照失败: {snap_err}", level="ERROR")
snapshot = None
epoch_payload = {
'snapshot': snapshot, # 已选通道的数据副本
'currentLabel': self.currentLabel,
'event_inner_idx': self.event_inner_idx,
'data_length': data_length,
'completion_time': now,
'decoder_class': self.decoder_class,
}
try:
self.epoch_queue.put_nowait(epoch_payload)
except queue.Full:
# 队列满时丢弃最旧的epoch保留最新
try:
self.epoch_queue.get_nowait()
except queue.Empty:
pass
try:
self.epoch_queue.put_nowait(epoch_payload)
algo_log("epoch_queue已满已丢弃最旧epoch", level="WARNING")
except queue.Full:
algo_log("epoch_queue放入失败", level="ERROR")
else:
self.paradigmBuffer.appendBuffer(data_np)
except Exception as e: except Exception as e:
algo_log(f"数据处理失败: {str(e)}", level="ERROR") algo_log(f"数据处理失败: {str(e)}", level="ERROR")