打标epoch_finished, event_inner_index,epoch data 走队列

This commit is contained in:
Ivey Song
2026-06-14 19:32:48 +08:00
parent 5d3cd0dba9
commit bfd3fa27b3
2 changed files with 143 additions and 66 deletions

View File

@@ -9,6 +9,7 @@ import numpy as np
import time
import torch
from queue import Empty
import queue
from scipy import signal
from torch.autograd import Variable
# from Device.SunnyLinker import SunnyLinker64
@@ -276,32 +277,28 @@ class Decoder_main(threading.Thread):
'''训练阶段采集数据'''
if self.zmqServer.state_mode == 'train': # 训练状态
if self.zmqServer.pack_contain_event:
with self.zmqServer.paradigmBufferLock:
self.zmqServer.paradigmBuffer.resetAllPara()
self.zmqServer.pack_contain_event = False
if self.zmqServer.epoch_finished:
data_length = self.zmqServer.paradigmBuffer.GetDataLenCount()
if data_length >= self.train_epoch[1] + self.zmqServer.event_inner_idx:
self.currentLabel = self.zmqServer.currentLabel
trainTrial = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 取出所有数据
algo_log(f"取出的:{trainTrial.shape}event{trainTrial[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
try:
epoch_payload = self.zmqServer.epoch_queue.get_nowait()
except queue.Empty:
time.sleep(0.0001)
return
trainTrial = epoch_payload['snapshot']
event_inner_idx = epoch_payload['event_inner_idx']
self.currentLabel = epoch_payload['currentLabel']
data_length = epoch_payload['data_length']
if trainTrial is None or data_length < self.train_epoch[1] + event_inner_idx:
algo_log(f"SSMVEP epoch数据长度不足: {data_length}, 跳过", level="WARNING")
return
algo_log(f"取出的:{trainTrial.shape}event{trainTrial[-2, event_inner_idx]}", level="DEBUG")
trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.train_epoch[
0]:self.zmqServer.event_inner_idx + self.train_epoch[1]]
trainTrial = trainTrial[:, event_inner_idx + self.train_epoch[
0]:event_inner_idx + self.train_epoch[1]]
if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance(
self.trainLabel, list) \
and self.trainLabel.count(self.currentLabel) < self.single_train:
self.trainData.append(trainTrial)
self.trainLabel.append(self.currentLabel)
algo_log(f"SSMVEP训练集{np.shape(self.trainData)}", level="DEBUG")
else:
algo_log(f"epoch_finished {self.zmqServer.epoch_finished}, 数据长度不足 {data_length}", level="DEBUG")
self.zmqServer.epoch_finished = False
else:
time.sleep(0.001)
return
elif self.zmqServer.state_mode == 'predict': # 测试状态
if self.load_model == False: # 模型尚未训练完成
@@ -313,18 +310,22 @@ class Decoder_main(threading.Thread):
now = datetime.now()
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
algo_log(f"SSMVEP模型启动预测 {formatted_time}", level="DEBUG")
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
self.interval_epoch[1] \
+ self.zmqServer.event_inner_idx:
# algo_log(f"SSMVEP模型启动预测 {self.zmqServer.epoch_finished}", level="DEBUG")
try:
epoch_payload = self.zmqServer.epoch_queue.get_nowait()
except queue.Empty:
time.sleep(0.0001)
return
data = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 读取全部数据
algo_log(f"取出的:{data.shape}, event: {data[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
data = epoch_payload['snapshot']
event_inner_idx = epoch_payload['event_inner_idx']
data_length = epoch_payload['data_length']
if data is None or data_length < self.interval_epoch[1] + event_inner_idx:
algo_log(f"SSMVEP predict epoch数据长度不足: {data_length}, 跳过", level="WARNING")
return
algo_log(f"取出的:{data.shape}, event: {data[-2, event_inner_idx]}", level="DEBUG")
data = self.preprocess(data[:self.n_chan, :]) # 预处理
data = data[:,
self.zmqServer.event_inner_idx + self.interval_epoch[
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]
event_inner_idx + self.interval_epoch[
0]:event_inner_idx + self.interval_epoch[1]]
pad_eeg_test = np.zeros(
(data.shape[0], int((self.sample_length + 0.1) * self.device_info['sample_rate'])))
pad_eeg_test[:, :int(self.sample_length * self.device_info['sample_rate'])] = data
@@ -384,28 +385,33 @@ class Decoder_main(threading.Thread):
'''训练阶段采集数据'''
if self.zmqServer.state_mode == 'train' and self.train_started == False: # 训练状态
if self.zmqServer.epoch_finished and self.zmqServer.paradigmBuffer.GetDataLenCount() >= \
self.zmqServer.train_epoch[1] + self.zmqServer.event_inner_idx:
self.currentLabel = self.zmqServer.currentLabel # 同步当前标签
algo_log(f"训练队列数据:{self.zmqServer.paradigmBuffer.GetDataLenCount()}", level="DEBUG")
originalTrial = self.zmqServer.paradigmBuffer.get_MIData() # 取出MI导联数据
algo_log(f"取出的:{originalTrial.shape},event: {originalTrial[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
try:
epoch_payload = self.zmqServer.epoch_queue.get_nowait()
except queue.Empty:
time.sleep(0.0001)
return
originalTrial = epoch_payload['snapshot']
event_inner_idx = epoch_payload['event_inner_idx']
self.currentLabel = epoch_payload['currentLabel']
data_length = epoch_payload['data_length']
if originalTrial is None or data_length < self.zmqServer.train_epoch[1] + event_inner_idx:
algo_log(f"epoch数据长度不足: {data_length}, 跳过", level="WARNING")
return
algo_log(f"训练队列数据:{data_length}", level="DEBUG")
algo_log(f"取出的:{originalTrial.shape},event: {originalTrial[-2, event_inner_idx]}", level="DEBUG")
trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.interval_epoch[
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]
# algo_log(f"trial: {self.zmqServer.event_inner_idx},{self.interval_epoch[0]},{self.interval_epoch[1]}", level="DEBUG")
trainTrial = trainTrial[:, event_inner_idx + self.interval_epoch[
0]:event_inner_idx + self.interval_epoch[1]]
# algo_log(f"trial: {event_inner_idx},{self.interval_epoch[0]},{self.interval_epoch[1]}", level="DEBUG")
if trainTrial.shape[1] == (self.interval_epoch[1] - self.interval_epoch[0]) and isinstance(self.trainLabel,
list) \
and self.trainLabel.count(self.currentLabel) < self.single_train:
self.trainData.append(trainTrial)
self.trainLabel.append(self.currentLabel)
algo_log(f"训练集:{np.shape(self.trainData)}", level="DEBUG")
self.plotData.append(originalTrial[:self.n_chan, self.zmqServer.event_inner_idx + self.interval_epoch[
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]])
self.plotData.append(originalTrial[:self.n_chan, event_inner_idx + self.interval_epoch[
0]:event_inner_idx + self.interval_epoch[1]])
self.plotLabel.append(self.currentLabel)
else:
time.sleep(0.001)
return
elif self.zmqServer.state_mode == 'predict' and self.load_model == True: # 测试状态
if self.zmqServer.StartDecode:
@@ -414,21 +420,26 @@ class Decoder_main(threading.Thread):
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
algo_log(f"MI启动预测 {formatted_time}", level="DEBUG")
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
self.interval_epoch[1] \
+ self.zmqServer.event_inner_idx:
try:
epoch_payload = self.zmqServer.epoch_queue.get_nowait()
except queue.Empty:
time.sleep(0.001)
return
originalData = self.zmqServer.paradigmBuffer.get_MIData() # 读取全部数据
algo_log(f"取出的:{originalData.shape},event: {originalData[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
originalData = epoch_payload['snapshot']
event_inner_idx = epoch_payload['event_inner_idx']
data_length = epoch_payload['data_length']
if originalData is None or data_length < self.interval_epoch[1] + event_inner_idx:
algo_log(f"predict epoch数据长度不足: {data_length}, 跳过", level="WARNING")
return
algo_log(f"取出的:{originalData.shape},event: {originalData[-2, event_inner_idx]}", level="DEBUG")
start = time.time()
data = self.preprocess(originalData[:self.n_chan, :]) # 预处理
data = data[:,
self.zmqServer.event_inner_idx + self.interval_epoch[
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]
event_inner_idx + self.interval_epoch[
0]:event_inner_idx + self.interval_epoch[1]]
self.plotData.append(
originalData[:self.n_chan, self.zmqServer.event_inner_idx + self.interval_epoch[
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]])
originalData[:self.n_chan, event_inner_idx + self.interval_epoch[
0]:event_inner_idx + self.interval_epoch[1]])
test_data = data[np.newaxis, np.newaxis, :, :]
test_data = torch.from_numpy(test_data)

View File

@@ -84,6 +84,10 @@ class zmqServer(threading.Thread):
self.cmd_send_queue = queue.Queue() # 8099端口命令结果队列
self.data_send_queue = queue.Queue() # 8100端口滤波数据队列
# epoch完成通知队列生产者(zmqServer)写入, 消费者(Decoder)读取
# 每个元素是一个dict包含完整的epoch数据快照避免裸标志位竞态
self.epoch_queue = queue.Queue(maxsize=10)
# 范式buffer与事件检测参数
self.predict_event = 99
self.events = [1, 2, self.predict_event]
@@ -105,6 +109,12 @@ class zmqServer(threading.Thread):
self.pack_contain_event = False
self.event_inner_idx = -1
self.interval_inited = False
# 清空epoch队列防止残留旧epoch被新阶段消费
while not self.epoch_queue.empty():
try:
self.epoch_queue.get_nowait()
except queue.Empty:
break
def interval_init(self, decoder_class):
if decoder_class == 'ssmvep':
@@ -132,6 +142,12 @@ class zmqServer(threading.Thread):
self.predict_event = 99
self.events = [1, 2, self.predict_event]
self.interval_inited = True
# 清空epoch队列防止旧范式的残留epoch被新阶段消费
while not self.epoch_queue.empty():
try:
self.epoch_queue.get_nowait()
except queue.Empty:
break
# -------------------------- 8099端口命令结果广播 --------------------------
def broadcast_message(self, method, params):
@@ -341,10 +357,60 @@ class zmqServer(threading.Thread):
with self.paradigmBufferLock:
self.paradigmBuffer.appendBuffer(data_np)
if self.interval_inited:
self.pack_contain_event, self.epoch_finished = self.detect_event(data_np)
if self.epoch_finished:
algo_log(f"Epoch采集完成, 当前数据长度{self.paradigmBuffer.GetDataLenCount()}", level="DEBUG")
self.epoch_finished = self.detect_event(data_np)
if self.pack_contain_event:
self.paradigmBuffer.resetAllPara()
self.paradigmBuffer.appendBuffer(data_np)
if self.epoch_finished:
now = datetime.datetime.now()
time_diff_str = ""
# 计算与上一次Epoch完成的时间差
if self.last_epoch_finish_time is not None:
delta_seconds = (now - self.last_epoch_finish_time).total_seconds()
time_diff_str = f" | 与上一次间隔: {delta_seconds:.3f} s"
log_msg = f"Epoch采集完成: {now.strftime('%H:%M:%S.%f')[:-3]}{time_diff_str}"
algo_log(log_msg, level="DEBUG")
self.last_epoch_finish_time = now
# ---- 队列化通知在锁内快照数据原子放入epoch_queue ----
# get_MIData / get_SSMVEPData 内部调用 getData(),返回连续数组副本
# 快照完成后再由 Decoder 线程消费,彻底消除 TOCTOU 竞态
data_length = self.paradigmBuffer.GetDataLenCount()
try:
if self.decoder_class == 'mi':
snapshot = self.paradigmBuffer.get_MIData()
elif self.decoder_class == 'ssmvep':
snapshot = self.paradigmBuffer.get_SSMVEPData()
else:
snapshot = None
except Exception as snap_err:
algo_log(f"epoch快照失败: {snap_err}", level="ERROR")
snapshot = None
epoch_payload = {
'snapshot': snapshot, # 已选通道的数据副本
'currentLabel': self.currentLabel,
'event_inner_idx': self.event_inner_idx,
'data_length': data_length,
'completion_time': now,
'decoder_class': self.decoder_class,
}
try:
self.epoch_queue.put_nowait(epoch_payload)
except queue.Full:
# 队列满时丢弃最旧的epoch保留最新
try:
self.epoch_queue.get_nowait()
except queue.Empty:
pass
try:
self.epoch_queue.put_nowait(epoch_payload)
algo_log("epoch_queue已满已丢弃最旧epoch", level="WARNING")
except queue.Full:
algo_log("epoch_queue放入失败", level="ERROR")
else:
self.paradigmBuffer.appendBuffer(data_np)
except Exception as e:
algo_log(f"数据处理失败: {str(e)}", level="ERROR")