打标epoch_finished, event_inner_index,epoch data 走队列
This commit is contained in:
137
Decoder.py
137
Decoder.py
@@ -9,6 +9,7 @@ import numpy as np
|
||||
import time
|
||||
import torch
|
||||
from queue import Empty
|
||||
import queue
|
||||
from scipy import signal
|
||||
from torch.autograd import Variable
|
||||
# from Device.SunnyLinker import SunnyLinker64
|
||||
@@ -276,32 +277,28 @@ class Decoder_main(threading.Thread):
|
||||
|
||||
'''训练阶段采集数据'''
|
||||
if self.zmqServer.state_mode == 'train': # 训练状态
|
||||
if self.zmqServer.pack_contain_event:
|
||||
with self.zmqServer.paradigmBufferLock:
|
||||
self.zmqServer.paradigmBuffer.resetAllPara()
|
||||
self.zmqServer.pack_contain_event = False
|
||||
|
||||
if self.zmqServer.epoch_finished:
|
||||
data_length = self.zmqServer.paradigmBuffer.GetDataLenCount()
|
||||
if data_length >= self.train_epoch[1] + self.zmqServer.event_inner_idx:
|
||||
self.currentLabel = self.zmqServer.currentLabel
|
||||
trainTrial = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 取出所有数据
|
||||
algo_log(f"取出的:{trainTrial.shape},event:{trainTrial[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
|
||||
trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理
|
||||
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.train_epoch[
|
||||
0]:self.zmqServer.event_inner_idx + self.train_epoch[1]]
|
||||
if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance(
|
||||
self.trainLabel, list) \
|
||||
and self.trainLabel.count(self.currentLabel) < self.single_train:
|
||||
self.trainData.append(trainTrial)
|
||||
self.trainLabel.append(self.currentLabel)
|
||||
algo_log(f"SSMVEP训练集:{np.shape(self.trainData)}", level="DEBUG")
|
||||
else:
|
||||
algo_log(f"epoch_finished {self.zmqServer.epoch_finished}, 数据长度不足 {data_length}", level="DEBUG")
|
||||
self.zmqServer.epoch_finished = False
|
||||
else:
|
||||
time.sleep(0.001)
|
||||
try:
|
||||
epoch_payload = self.zmqServer.epoch_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
time.sleep(0.0001)
|
||||
return
|
||||
trainTrial = epoch_payload['snapshot']
|
||||
event_inner_idx = epoch_payload['event_inner_idx']
|
||||
self.currentLabel = epoch_payload['currentLabel']
|
||||
data_length = epoch_payload['data_length']
|
||||
if trainTrial is None or data_length < self.train_epoch[1] + event_inner_idx:
|
||||
algo_log(f"SSMVEP epoch数据长度不足: {data_length}, 跳过", level="WARNING")
|
||||
return
|
||||
algo_log(f"取出的:{trainTrial.shape},event:{trainTrial[-2, event_inner_idx]}", level="DEBUG")
|
||||
trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理
|
||||
trainTrial = trainTrial[:, event_inner_idx + self.train_epoch[
|
||||
0]:event_inner_idx + self.train_epoch[1]]
|
||||
if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance(
|
||||
self.trainLabel, list) \
|
||||
and self.trainLabel.count(self.currentLabel) < self.single_train:
|
||||
self.trainData.append(trainTrial)
|
||||
self.trainLabel.append(self.currentLabel)
|
||||
algo_log(f"SSMVEP训练集:{np.shape(self.trainData)}", level="DEBUG")
|
||||
|
||||
elif self.zmqServer.state_mode == 'predict': # 测试状态
|
||||
if self.load_model == False: # 模型尚未训练完成
|
||||
@@ -313,18 +310,22 @@ class Decoder_main(threading.Thread):
|
||||
now = datetime.now()
|
||||
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
||||
algo_log(f"SSMVEP模型启动预测 {formatted_time}", level="DEBUG")
|
||||
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
||||
self.interval_epoch[1] \
|
||||
+ self.zmqServer.event_inner_idx:
|
||||
# algo_log(f"SSMVEP模型启动预测 {self.zmqServer.epoch_finished}", level="DEBUG")
|
||||
try:
|
||||
epoch_payload = self.zmqServer.epoch_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
time.sleep(0.0001)
|
||||
return
|
||||
data = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 读取全部数据
|
||||
algo_log(f"取出的:{data.shape}, event: {data[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
|
||||
data = epoch_payload['snapshot']
|
||||
event_inner_idx = epoch_payload['event_inner_idx']
|
||||
data_length = epoch_payload['data_length']
|
||||
if data is None or data_length < self.interval_epoch[1] + event_inner_idx:
|
||||
algo_log(f"SSMVEP predict epoch数据长度不足: {data_length}, 跳过", level="WARNING")
|
||||
return
|
||||
algo_log(f"取出的:{data.shape}, event: {data[-2, event_inner_idx]}", level="DEBUG")
|
||||
data = self.preprocess(data[:self.n_chan, :]) # 预处理
|
||||
data = data[:,
|
||||
self.zmqServer.event_inner_idx + self.interval_epoch[
|
||||
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]
|
||||
event_inner_idx + self.interval_epoch[
|
||||
0]:event_inner_idx + self.interval_epoch[1]]
|
||||
pad_eeg_test = np.zeros(
|
||||
(data.shape[0], int((self.sample_length + 0.1) * self.device_info['sample_rate'])))
|
||||
pad_eeg_test[:, :int(self.sample_length * self.device_info['sample_rate'])] = data
|
||||
@@ -384,28 +385,33 @@ class Decoder_main(threading.Thread):
|
||||
|
||||
'''训练阶段采集数据'''
|
||||
if self.zmqServer.state_mode == 'train' and self.train_started == False: # 训练状态
|
||||
if self.zmqServer.epoch_finished and self.zmqServer.paradigmBuffer.GetDataLenCount() >= \
|
||||
self.zmqServer.train_epoch[1] + self.zmqServer.event_inner_idx:
|
||||
self.currentLabel = self.zmqServer.currentLabel # 同步当前标签
|
||||
algo_log(f"训练队列数据:{self.zmqServer.paradigmBuffer.GetDataLenCount()}", level="DEBUG")
|
||||
originalTrial = self.zmqServer.paradigmBuffer.get_MIData() # 取出MI导联数据
|
||||
algo_log(f"取出的:{originalTrial.shape},event: {originalTrial[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
|
||||
trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理
|
||||
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.interval_epoch[
|
||||
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]
|
||||
# algo_log(f"trial: {self.zmqServer.event_inner_idx},{self.interval_epoch[0]},{self.interval_epoch[1]}", level="DEBUG")
|
||||
if trainTrial.shape[1] == (self.interval_epoch[1] - self.interval_epoch[0]) and isinstance(self.trainLabel,
|
||||
list) \
|
||||
and self.trainLabel.count(self.currentLabel) < self.single_train:
|
||||
self.trainData.append(trainTrial)
|
||||
self.trainLabel.append(self.currentLabel)
|
||||
algo_log(f"训练集:{np.shape(self.trainData)}", level="DEBUG")
|
||||
self.plotData.append(originalTrial[:self.n_chan, self.zmqServer.event_inner_idx + self.interval_epoch[
|
||||
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]])
|
||||
self.plotLabel.append(self.currentLabel)
|
||||
else:
|
||||
time.sleep(0.001)
|
||||
try:
|
||||
epoch_payload = self.zmqServer.epoch_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
time.sleep(0.0001)
|
||||
return
|
||||
originalTrial = epoch_payload['snapshot']
|
||||
event_inner_idx = epoch_payload['event_inner_idx']
|
||||
self.currentLabel = epoch_payload['currentLabel']
|
||||
data_length = epoch_payload['data_length']
|
||||
if originalTrial is None or data_length < self.zmqServer.train_epoch[1] + event_inner_idx:
|
||||
algo_log(f"epoch数据长度不足: {data_length}, 跳过", level="WARNING")
|
||||
return
|
||||
algo_log(f"训练队列数据:{data_length}", level="DEBUG")
|
||||
algo_log(f"取出的:{originalTrial.shape},event: {originalTrial[-2, event_inner_idx]}", level="DEBUG")
|
||||
trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理
|
||||
trainTrial = trainTrial[:, event_inner_idx + self.interval_epoch[
|
||||
0]:event_inner_idx + self.interval_epoch[1]]
|
||||
# algo_log(f"trial: {event_inner_idx},{self.interval_epoch[0]},{self.interval_epoch[1]}", level="DEBUG")
|
||||
if trainTrial.shape[1] == (self.interval_epoch[1] - self.interval_epoch[0]) and isinstance(self.trainLabel,
|
||||
list) \
|
||||
and self.trainLabel.count(self.currentLabel) < self.single_train:
|
||||
self.trainData.append(trainTrial)
|
||||
self.trainLabel.append(self.currentLabel)
|
||||
algo_log(f"训练集:{np.shape(self.trainData)}", level="DEBUG")
|
||||
self.plotData.append(originalTrial[:self.n_chan, event_inner_idx + self.interval_epoch[
|
||||
0]:event_inner_idx + self.interval_epoch[1]])
|
||||
self.plotLabel.append(self.currentLabel)
|
||||
|
||||
elif self.zmqServer.state_mode == 'predict' and self.load_model == True: # 测试状态
|
||||
if self.zmqServer.StartDecode:
|
||||
@@ -414,21 +420,26 @@ class Decoder_main(threading.Thread):
|
||||
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
||||
algo_log(f"MI启动预测 {formatted_time}", level="DEBUG")
|
||||
|
||||
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
||||
self.interval_epoch[1] \
|
||||
+ self.zmqServer.event_inner_idx:
|
||||
try:
|
||||
epoch_payload = self.zmqServer.epoch_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
time.sleep(0.001)
|
||||
return
|
||||
originalData = self.zmqServer.paradigmBuffer.get_MIData() # 读取全部数据
|
||||
algo_log(f"取出的:{originalData.shape},event: {originalData[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
|
||||
originalData = epoch_payload['snapshot']
|
||||
event_inner_idx = epoch_payload['event_inner_idx']
|
||||
data_length = epoch_payload['data_length']
|
||||
if originalData is None or data_length < self.interval_epoch[1] + event_inner_idx:
|
||||
algo_log(f"predict epoch数据长度不足: {data_length}, 跳过", level="WARNING")
|
||||
return
|
||||
algo_log(f"取出的:{originalData.shape},event: {originalData[-2, event_inner_idx]}", level="DEBUG")
|
||||
start = time.time()
|
||||
data = self.preprocess(originalData[:self.n_chan, :]) # 预处理
|
||||
data = data[:,
|
||||
self.zmqServer.event_inner_idx + self.interval_epoch[
|
||||
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]
|
||||
event_inner_idx + self.interval_epoch[
|
||||
0]:event_inner_idx + self.interval_epoch[1]]
|
||||
self.plotData.append(
|
||||
originalData[:self.n_chan, self.zmqServer.event_inner_idx + self.interval_epoch[
|
||||
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]])
|
||||
originalData[:self.n_chan, event_inner_idx + self.interval_epoch[
|
||||
0]:event_inner_idx + self.interval_epoch[1]])
|
||||
|
||||
test_data = data[np.newaxis, np.newaxis, :, :]
|
||||
test_data = torch.from_numpy(test_data)
|
||||
|
||||
@@ -84,6 +84,10 @@ class zmqServer(threading.Thread):
|
||||
self.cmd_send_queue = queue.Queue() # 8099端口命令结果队列
|
||||
self.data_send_queue = queue.Queue() # 8100端口滤波数据队列
|
||||
|
||||
# epoch完成通知队列:生产者(zmqServer)写入, 消费者(Decoder)读取
|
||||
# 每个元素是一个dict,包含完整的epoch数据快照,避免裸标志位竞态
|
||||
self.epoch_queue = queue.Queue(maxsize=10)
|
||||
|
||||
# 范式buffer与事件检测参数
|
||||
self.predict_event = 99
|
||||
self.events = [1, 2, self.predict_event]
|
||||
@@ -105,6 +109,12 @@ class zmqServer(threading.Thread):
|
||||
self.pack_contain_event = False
|
||||
self.event_inner_idx = -1
|
||||
self.interval_inited = False
|
||||
# 清空epoch队列,防止残留旧epoch被新阶段消费
|
||||
while not self.epoch_queue.empty():
|
||||
try:
|
||||
self.epoch_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
def interval_init(self, decoder_class):
|
||||
if decoder_class == 'ssmvep':
|
||||
@@ -132,6 +142,12 @@ class zmqServer(threading.Thread):
|
||||
self.predict_event = 99
|
||||
self.events = [1, 2, self.predict_event]
|
||||
self.interval_inited = True
|
||||
# 清空epoch队列,防止旧范式的残留epoch被新阶段消费
|
||||
while not self.epoch_queue.empty():
|
||||
try:
|
||||
self.epoch_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
# -------------------------- 8099端口:命令结果广播 --------------------------
|
||||
def broadcast_message(self, method, params):
|
||||
@@ -341,10 +357,60 @@ class zmqServer(threading.Thread):
|
||||
with self.paradigmBufferLock:
|
||||
self.paradigmBuffer.appendBuffer(data_np)
|
||||
if self.interval_inited:
|
||||
self.pack_contain_event, self.epoch_finished = self.detect_event(data_np)
|
||||
if self.epoch_finished:
|
||||
algo_log(f"Epoch采集完成, 当前数据长度{self.paradigmBuffer.GetDataLenCount()}", level="DEBUG")
|
||||
self.epoch_finished = self.detect_event(data_np)
|
||||
if self.pack_contain_event:
|
||||
self.paradigmBuffer.resetAllPara()
|
||||
self.paradigmBuffer.appendBuffer(data_np)
|
||||
|
||||
if self.epoch_finished:
|
||||
now = datetime.datetime.now()
|
||||
time_diff_str = ""
|
||||
# 计算与上一次Epoch完成的时间差
|
||||
if self.last_epoch_finish_time is not None:
|
||||
delta_seconds = (now - self.last_epoch_finish_time).total_seconds()
|
||||
time_diff_str = f" | 与上一次间隔: {delta_seconds:.3f} s"
|
||||
log_msg = f"Epoch采集完成: {now.strftime('%H:%M:%S.%f')[:-3]}{time_diff_str}"
|
||||
algo_log(log_msg, level="DEBUG")
|
||||
self.last_epoch_finish_time = now
|
||||
|
||||
# ---- 队列化通知:在锁内快照数据,原子放入epoch_queue ----
|
||||
# get_MIData / get_SSMVEPData 内部调用 getData(),返回连续数组副本
|
||||
# 快照完成后再由 Decoder 线程消费,彻底消除 TOCTOU 竞态
|
||||
data_length = self.paradigmBuffer.GetDataLenCount()
|
||||
try:
|
||||
if self.decoder_class == 'mi':
|
||||
snapshot = self.paradigmBuffer.get_MIData()
|
||||
elif self.decoder_class == 'ssmvep':
|
||||
snapshot = self.paradigmBuffer.get_SSMVEPData()
|
||||
else:
|
||||
snapshot = None
|
||||
except Exception as snap_err:
|
||||
algo_log(f"epoch快照失败: {snap_err}", level="ERROR")
|
||||
snapshot = None
|
||||
|
||||
epoch_payload = {
|
||||
'snapshot': snapshot, # 已选通道的数据副本
|
||||
'currentLabel': self.currentLabel,
|
||||
'event_inner_idx': self.event_inner_idx,
|
||||
'data_length': data_length,
|
||||
'completion_time': now,
|
||||
'decoder_class': self.decoder_class,
|
||||
}
|
||||
try:
|
||||
self.epoch_queue.put_nowait(epoch_payload)
|
||||
except queue.Full:
|
||||
# 队列满时丢弃最旧的epoch,保留最新
|
||||
try:
|
||||
self.epoch_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
pass
|
||||
try:
|
||||
self.epoch_queue.put_nowait(epoch_payload)
|
||||
algo_log("epoch_queue已满,已丢弃最旧epoch", level="WARNING")
|
||||
except queue.Full:
|
||||
algo_log("epoch_queue放入失败", level="ERROR")
|
||||
else:
|
||||
self.paradigmBuffer.appendBuffer(data_np)
|
||||
|
||||
except Exception as e:
|
||||
algo_log(f"数据处理失败: {str(e)}", level="ERROR")
|
||||
|
||||
Reference in New Issue
Block a user