打标epoch_finished, event_inner_index,epoch data 走队列

This commit is contained in:
Ivey Song
2026-06-14 19:32:48 +08:00
parent 5d3cd0dba9
commit bfd3fa27b3
2 changed files with 143 additions and 66 deletions

View File

@@ -9,6 +9,7 @@ import numpy as np
import time
import torch
from queue import Empty
import queue
from scipy import signal
from torch.autograd import Variable
# from Device.SunnyLinker import SunnyLinker64
@@ -276,32 +277,28 @@ class Decoder_main(threading.Thread):
'''训练阶段采集数据'''
if self.zmqServer.state_mode == 'train': # 训练状态
if self.zmqServer.pack_contain_event:
with self.zmqServer.paradigmBufferLock:
self.zmqServer.paradigmBuffer.resetAllPara()
self.zmqServer.pack_contain_event = False
if self.zmqServer.epoch_finished:
data_length = self.zmqServer.paradigmBuffer.GetDataLenCount()
if data_length >= self.train_epoch[1] + self.zmqServer.event_inner_idx:
self.currentLabel = self.zmqServer.currentLabel
trainTrial = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 取出所有数据
algo_log(f"取出的:{trainTrial.shape}event{trainTrial[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.train_epoch[
0]:self.zmqServer.event_inner_idx + self.train_epoch[1]]
if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance(
self.trainLabel, list) \
and self.trainLabel.count(self.currentLabel) < self.single_train:
self.trainData.append(trainTrial)
self.trainLabel.append(self.currentLabel)
algo_log(f"SSMVEP训练集{np.shape(self.trainData)}", level="DEBUG")
else:
algo_log(f"epoch_finished {self.zmqServer.epoch_finished}, 数据长度不足 {data_length}", level="DEBUG")
self.zmqServer.epoch_finished = False
else:
time.sleep(0.001)
try:
epoch_payload = self.zmqServer.epoch_queue.get_nowait()
except queue.Empty:
time.sleep(0.0001)
return
trainTrial = epoch_payload['snapshot']
event_inner_idx = epoch_payload['event_inner_idx']
self.currentLabel = epoch_payload['currentLabel']
data_length = epoch_payload['data_length']
if trainTrial is None or data_length < self.train_epoch[1] + event_inner_idx:
algo_log(f"SSMVEP epoch数据长度不足: {data_length}, 跳过", level="WARNING")
return
algo_log(f"取出的:{trainTrial.shape}event{trainTrial[-2, event_inner_idx]}", level="DEBUG")
trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理
trainTrial = trainTrial[:, event_inner_idx + self.train_epoch[
0]:event_inner_idx + self.train_epoch[1]]
if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance(
self.trainLabel, list) \
and self.trainLabel.count(self.currentLabel) < self.single_train:
self.trainData.append(trainTrial)
self.trainLabel.append(self.currentLabel)
algo_log(f"SSMVEP训练集{np.shape(self.trainData)}", level="DEBUG")
elif self.zmqServer.state_mode == 'predict': # 测试状态
if self.load_model == False: # 模型尚未训练完成
@@ -313,18 +310,22 @@ class Decoder_main(threading.Thread):
now = datetime.now()
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
algo_log(f"SSMVEP模型启动预测 {formatted_time}", level="DEBUG")
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
self.interval_epoch[1] \
+ self.zmqServer.event_inner_idx:
# algo_log(f"SSMVEP模型启动预测 {self.zmqServer.epoch_finished}", level="DEBUG")
try:
epoch_payload = self.zmqServer.epoch_queue.get_nowait()
except queue.Empty:
time.sleep(0.0001)
return
data = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 读取全部数据
algo_log(f"取出的:{data.shape}, event: {data[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
data = epoch_payload['snapshot']
event_inner_idx = epoch_payload['event_inner_idx']
data_length = epoch_payload['data_length']
if data is None or data_length < self.interval_epoch[1] + event_inner_idx:
algo_log(f"SSMVEP predict epoch数据长度不足: {data_length}, 跳过", level="WARNING")
return
algo_log(f"取出的:{data.shape}, event: {data[-2, event_inner_idx]}", level="DEBUG")
data = self.preprocess(data[:self.n_chan, :]) # 预处理
data = data[:,
self.zmqServer.event_inner_idx + self.interval_epoch[
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]
event_inner_idx + self.interval_epoch[
0]:event_inner_idx + self.interval_epoch[1]]
pad_eeg_test = np.zeros(
(data.shape[0], int((self.sample_length + 0.1) * self.device_info['sample_rate'])))
pad_eeg_test[:, :int(self.sample_length * self.device_info['sample_rate'])] = data
@@ -384,28 +385,33 @@ class Decoder_main(threading.Thread):
'''训练阶段采集数据'''
if self.zmqServer.state_mode == 'train' and self.train_started == False: # 训练状态
if self.zmqServer.epoch_finished and self.zmqServer.paradigmBuffer.GetDataLenCount() >= \
self.zmqServer.train_epoch[1] + self.zmqServer.event_inner_idx:
self.currentLabel = self.zmqServer.currentLabel # 同步当前标签
algo_log(f"训练队列数据:{self.zmqServer.paradigmBuffer.GetDataLenCount()}", level="DEBUG")
originalTrial = self.zmqServer.paradigmBuffer.get_MIData() # 取出MI导联数据
algo_log(f"取出的:{originalTrial.shape},event: {originalTrial[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.interval_epoch[
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]
# algo_log(f"trial: {self.zmqServer.event_inner_idx},{self.interval_epoch[0]},{self.interval_epoch[1]}", level="DEBUG")
if trainTrial.shape[1] == (self.interval_epoch[1] - self.interval_epoch[0]) and isinstance(self.trainLabel,
list) \
and self.trainLabel.count(self.currentLabel) < self.single_train:
self.trainData.append(trainTrial)
self.trainLabel.append(self.currentLabel)
algo_log(f"训练集:{np.shape(self.trainData)}", level="DEBUG")
self.plotData.append(originalTrial[:self.n_chan, self.zmqServer.event_inner_idx + self.interval_epoch[
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]])
self.plotLabel.append(self.currentLabel)
else:
time.sleep(0.001)
try:
epoch_payload = self.zmqServer.epoch_queue.get_nowait()
except queue.Empty:
time.sleep(0.0001)
return
originalTrial = epoch_payload['snapshot']
event_inner_idx = epoch_payload['event_inner_idx']
self.currentLabel = epoch_payload['currentLabel']
data_length = epoch_payload['data_length']
if originalTrial is None or data_length < self.zmqServer.train_epoch[1] + event_inner_idx:
algo_log(f"epoch数据长度不足: {data_length}, 跳过", level="WARNING")
return
algo_log(f"训练队列数据:{data_length}", level="DEBUG")
algo_log(f"取出的:{originalTrial.shape},event: {originalTrial[-2, event_inner_idx]}", level="DEBUG")
trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理
trainTrial = trainTrial[:, event_inner_idx + self.interval_epoch[
0]:event_inner_idx + self.interval_epoch[1]]
# algo_log(f"trial: {event_inner_idx},{self.interval_epoch[0]},{self.interval_epoch[1]}", level="DEBUG")
if trainTrial.shape[1] == (self.interval_epoch[1] - self.interval_epoch[0]) and isinstance(self.trainLabel,
list) \
and self.trainLabel.count(self.currentLabel) < self.single_train:
self.trainData.append(trainTrial)
self.trainLabel.append(self.currentLabel)
algo_log(f"训练集:{np.shape(self.trainData)}", level="DEBUG")
self.plotData.append(originalTrial[:self.n_chan, event_inner_idx + self.interval_epoch[
0]:event_inner_idx + self.interval_epoch[1]])
self.plotLabel.append(self.currentLabel)
elif self.zmqServer.state_mode == 'predict' and self.load_model == True: # 测试状态
if self.zmqServer.StartDecode:
@@ -414,21 +420,26 @@ class Decoder_main(threading.Thread):
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
algo_log(f"MI启动预测 {formatted_time}", level="DEBUG")
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
self.interval_epoch[1] \
+ self.zmqServer.event_inner_idx:
try:
epoch_payload = self.zmqServer.epoch_queue.get_nowait()
except queue.Empty:
time.sleep(0.001)
return
originalData = self.zmqServer.paradigmBuffer.get_MIData() # 读取全部数据
algo_log(f"取出的:{originalData.shape},event: {originalData[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
originalData = epoch_payload['snapshot']
event_inner_idx = epoch_payload['event_inner_idx']
data_length = epoch_payload['data_length']
if originalData is None or data_length < self.interval_epoch[1] + event_inner_idx:
algo_log(f"predict epoch数据长度不足: {data_length}, 跳过", level="WARNING")
return
algo_log(f"取出的:{originalData.shape},event: {originalData[-2, event_inner_idx]}", level="DEBUG")
start = time.time()
data = self.preprocess(originalData[:self.n_chan, :]) # 预处理
data = data[:,
self.zmqServer.event_inner_idx + self.interval_epoch[
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]
event_inner_idx + self.interval_epoch[
0]:event_inner_idx + self.interval_epoch[1]]
self.plotData.append(
originalData[:self.n_chan, self.zmqServer.event_inner_idx + self.interval_epoch[
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]])
originalData[:self.n_chan, event_inner_idx + self.interval_epoch[
0]:event_inner_idx + self.interval_epoch[1]])
test_data = data[np.newaxis, np.newaxis, :, :]
test_data = torch.from_numpy(test_data)