Compare commits
31 Commits
694321b52c
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| c27e250fad | |||
| 66c0b71b89 | |||
| 5c7b73b7a4 | |||
| 9690971f43 | |||
| 5a5f103ef6 | |||
| b31bb18dfe | |||
| 38480a2ca3 | |||
| 62e7cab5be | |||
|
|
b26ae2ce3c | ||
| 5488626112 | |||
| d59b0f695f | |||
| 0570d41439 | |||
| 4574798d86 | |||
| d480107b37 | |||
| 2d70fc9956 | |||
|
|
1bbe84eb56 | ||
|
|
f21367bc20 | ||
| ba4ae92647 | |||
| 43adc6fb42 | |||
| b329989181 | |||
| 68106d8aed | |||
|
|
506ebfd973 | ||
|
|
5a2cc82100 | ||
|
|
81a8d78ab2 | ||
| 73e01782df | |||
| b78e583bec | |||
| 504e89ee47 | |||
|
|
a9dbe7261b | ||
| 7b5f4f6eb9 | |||
| 0cffd1ae02 | |||
| 0e5e79fcdd |
11
.gitignore
vendored
11
.gitignore
vendored
@@ -2,10 +2,14 @@
|
|||||||
__pycache__/
|
__pycache__/
|
||||||
|
|
||||||
# Distribution / packaging
|
# Distribution / packaging
|
||||||
|
release/
|
||||||
build/
|
build/
|
||||||
dist/
|
dist/
|
||||||
|
dist_nuitka/
|
||||||
# Environments
|
upperHost_stim/
|
||||||
|
.vscode/
|
||||||
|
#!upperHost_stim/MI_headless.py
|
||||||
|
#!upperHost_stim/ssmvep_headless.py
|
||||||
.env
|
.env
|
||||||
.venv
|
.venv
|
||||||
env/
|
env/
|
||||||
@@ -24,7 +28,8 @@ venv.bak/
|
|||||||
*.xlsx
|
*.xlsx
|
||||||
*.mat
|
*.mat
|
||||||
*.json
|
*.json
|
||||||
|
*.txt
|
||||||
|
*.pth
|
||||||
|
|
||||||
# PyCharm
|
# PyCharm
|
||||||
# JetBrains specific template is maintained in a separate repository that is not distributed with PyCharm itself
|
# JetBrains specific template is maintained in a separate repository that is not distributed with PyCharm itself
|
||||||
|
|||||||
40
Decoder.py
40
Decoder.py
@@ -14,8 +14,8 @@ from torch.autograd import Variable
|
|||||||
# from Device.SunnyLinker import SunnyLinker64
|
# from Device.SunnyLinker import SunnyLinker64
|
||||||
from SSMVEP.algorithm.tdca import TDCA
|
from SSMVEP.algorithm.tdca import TDCA
|
||||||
from SSMVEP.algorithm.base import generate_cca_references
|
from SSMVEP.algorithm.base import generate_cca_references
|
||||||
from concentration.algorithm.calculate_focus import Calculate
|
# from concentration.algorithm.calculate_focus import Calculate
|
||||||
from blinkdetection.algorithm.eye_detection import blink_detection
|
# from blinkdetection.algorithm.eye_detection import blink_detection
|
||||||
from Zmq.zmqServer import zmqServer
|
from Zmq.zmqServer import zmqServer
|
||||||
from Zmq.zmqClient import zmqClient
|
from Zmq.zmqClient import zmqClient
|
||||||
from MI.Algorithm.conformer_2class import onlineTrain
|
from MI.Algorithm.conformer_2class import onlineTrain
|
||||||
@@ -62,6 +62,8 @@ class Decoder_main(threading.Thread):
|
|||||||
|
|
||||||
# 注册滤波结果回调(示例:打印数据形状)
|
# 注册滤波结果回调(示例:打印数据形状)
|
||||||
self.sliding_filter.filter_result_callback = self.zmqServer.send_filtered_data
|
self.sliding_filter.filter_result_callback = self.zmqServer.send_filtered_data
|
||||||
|
# 注册 beta_psd 广播回调,每秒通过 8099 端口发送给上位机
|
||||||
|
self.sliding_filter.set_beta_broadcast_callback(lambda v: self.zmqServer.broadcast_message('beta_psd', v))
|
||||||
|
|
||||||
def is_valid_signal(self, data, threshold=1e5): # 判断当前信号是否为有效信号
|
def is_valid_signal(self, data, threshold=1e5): # 判断当前信号是否为有效信号
|
||||||
# data: (chans, samples)
|
# data: (chans, samples)
|
||||||
@@ -76,7 +78,7 @@ class Decoder_main(threading.Thread):
|
|||||||
:return:
|
:return:
|
||||||
'''
|
'''
|
||||||
self.decoder_class = decoder_class
|
self.decoder_class = decoder_class
|
||||||
if decoder_class == 'ssvep' or decoder_class == 'pvs':
|
if self.decoder_class == 'ssvep' or self.decoder_class == 'pvs':
|
||||||
self.n_chan = 8
|
self.n_chan = 8
|
||||||
# self.thread_data_server.interval_inited = False
|
# self.thread_data_server.interval_inited = False
|
||||||
DW_cost_method, self.DW_cost_tv = ast.literal_eval(IniRead('system', 'SSVEP_ThresholdValue'))
|
DW_cost_method, self.DW_cost_tv = ast.literal_eval(IniRead('system', 'SSVEP_ThresholdValue'))
|
||||||
@@ -112,8 +114,8 @@ class Decoder_main(threading.Thread):
|
|||||||
elif decoder_class == 'mi' or decoder_class == 'ma':
|
elif decoder_class == 'mi' or decoder_class == 'ma':
|
||||||
self.zmqServer.interval_init(decoder_class)
|
self.zmqServer.interval_init(decoder_class)
|
||||||
self.n_chan = 21
|
self.n_chan = 21
|
||||||
self.interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
|
self.interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch')) # [0.5, 4.5]
|
||||||
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位
|
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度4s,# 精确到小数点后6位
|
||||||
self.single_train = 40 # 单类别数量
|
self.single_train = 40 # 单类别数量
|
||||||
self.num_target = 2 # 分类目标数目
|
self.num_target = 2 # 分类目标数目
|
||||||
|
|
||||||
@@ -155,8 +157,8 @@ class Decoder_main(threading.Thread):
|
|||||||
# self.blink_b, self.blink_a = signal.butter(4, [self.l_freq / (self.device_info['sample_rate'] / 2), self.h_freq / (self.device_info['sample_rate'] / 2)], btype='band')
|
# self.blink_b, self.blink_a = signal.butter(4, [self.l_freq / (self.device_info['sample_rate'] / 2), self.h_freq / (self.device_info['sample_rate'] / 2)], btype='band')
|
||||||
|
|
||||||
def parameter_init(self,bandPass_low,bandPass_high):
|
def parameter_init(self,bandPass_low,bandPass_high):
|
||||||
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in self.interval_epoch] # epoch截取信息
|
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in self.interval_epoch] # epoch截取信息 ssmvep [50, 550]
|
||||||
self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])] # 训练样本epoch
|
self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * self.device_info['sample_rate'])] # 训练样本epoch ssmevep [50, 575]
|
||||||
self.trainData = [] #训练数据
|
self.trainData = [] #训练数据
|
||||||
self.trainLabel = [] #训练标签
|
self.trainLabel = [] #训练标签
|
||||||
self.plotData = [] #报告分析数据
|
self.plotData = [] #报告分析数据
|
||||||
@@ -204,6 +206,9 @@ class Decoder_main(threading.Thread):
|
|||||||
self.zmqServer.state_mode = 'rest'
|
self.zmqServer.state_mode = 'rest'
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if self.zmqServer.open_Impedance:
|
||||||
|
time.sleep(0.005)
|
||||||
|
continue
|
||||||
if self.decoder_class == 'ssvep' or self.decoder_class == 'pvs':
|
if self.decoder_class == 'ssvep' or self.decoder_class == 'pvs':
|
||||||
self.decoder_SSVEP()
|
self.decoder_SSVEP()
|
||||||
elif self.decoder_class == 'ssmvep':
|
elif self.decoder_class == 'ssmvep':
|
||||||
@@ -213,7 +218,7 @@ class Decoder_main(threading.Thread):
|
|||||||
else:
|
else:
|
||||||
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
|
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
|
||||||
time.sleep(0.005)
|
time.sleep(0.005)
|
||||||
continue;
|
continue
|
||||||
self.zmqServer.paradigmBuffer.getData(25)
|
self.zmqServer.paradigmBuffer.getData(25)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
algo_log(f"Decoder Loop Error: {e}")
|
algo_log(f"Decoder Loop Error: {e}")
|
||||||
@@ -231,7 +236,7 @@ class Decoder_main(threading.Thread):
|
|||||||
if self.zmqServer.open_Impedance: # 阻抗检测状态不解码
|
if self.zmqServer.open_Impedance: # 阻抗检测状态不解码
|
||||||
return
|
return
|
||||||
data = self.zmqServer.paradigmBuffer.getDataViaSSVEP(50)
|
data = self.zmqServer.paradigmBuffer.getDataViaSSVEP(50)
|
||||||
algo_log(f"SSVEP取出的:{data.shape}, data = {data[:20]}", level="DEBUG")
|
# algo_log(f"SSVEP取出的:{data.shape}, data = {data[:, :10]}", level="DEBUG")
|
||||||
data = data[:self.n_chan, :]
|
data = data[:self.n_chan, :]
|
||||||
if self.decodingSteps == 1 and hasattr(self,'dw'): # 开始预热
|
if self.decodingSteps == 1 and hasattr(self,'dw'): # 开始预热
|
||||||
self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时
|
self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时
|
||||||
@@ -254,7 +259,7 @@ class Decoder_main(threading.Thread):
|
|||||||
def decoder_SSMVEP(self):
|
def decoder_SSMVEP(self):
|
||||||
'''模型训练'''
|
'''模型训练'''
|
||||||
if self.load_model == False and all(
|
if self.load_model == False and all(
|
||||||
self.trainLabel.count(i) >= self.single_train for i in range(len(self.list_freqs))): # 模型尚未训练完成
|
self.trainLabel.count(i) >= self.single_train for i in [1, 2]): # 模型尚未训练完成
|
||||||
self.trainData = np.array(self.trainData)
|
self.trainData = np.array(self.trainData)
|
||||||
self.trainLabel = np.array(self.trainLabel)
|
self.trainLabel = np.array(self.trainLabel)
|
||||||
algo_log(f"开始SSMVEP模型训练,数据形状:{np.shape(self.trainData)},标签形状:{self.trainLabel.shape}", level="DEBUG")
|
algo_log(f"开始SSMVEP模型训练,数据形状:{np.shape(self.trainData)},标签形状:{self.trainLabel.shape}", level="DEBUG")
|
||||||
@@ -284,6 +289,7 @@ class Decoder_main(threading.Thread):
|
|||||||
and self.trainLabel.count(self.currentLabel) < self.single_train:
|
and self.trainLabel.count(self.currentLabel) < self.single_train:
|
||||||
self.trainData.append(trainTrial)
|
self.trainData.append(trainTrial)
|
||||||
self.trainLabel.append(self.currentLabel)
|
self.trainLabel.append(self.currentLabel)
|
||||||
|
algo_log(f"SSMVEP训练集:{np.shape(self.trainData)}", level="DEBUG")
|
||||||
else:
|
else:
|
||||||
time.sleep(0.0001)
|
time.sleep(0.0001)
|
||||||
return
|
return
|
||||||
@@ -301,6 +307,7 @@ class Decoder_main(threading.Thread):
|
|||||||
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
||||||
self.interval_epoch[1] \
|
self.interval_epoch[1] \
|
||||||
+ self.zmqServer.event_inner_idx:
|
+ self.zmqServer.event_inner_idx:
|
||||||
|
# algo_log(f"SSMVEP模型启动预测 {self.zmqServer.epoch_finished}", level="DEBUG")
|
||||||
time.sleep(0.0001)
|
time.sleep(0.0001)
|
||||||
return
|
return
|
||||||
data = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 读取全部数据
|
data = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 读取全部数据
|
||||||
@@ -327,11 +334,11 @@ class Decoder_main(threading.Thread):
|
|||||||
def decoder_MI(self):
|
def decoder_MI(self):
|
||||||
'''模型训练'''
|
'''模型训练'''
|
||||||
if self.train_started == False and all(
|
if self.train_started == False and all(
|
||||||
self.trainLabel.count(i) >= self.single_train for i in range(self.num_target)): # 模型尚未训练
|
self.trainLabel.count(i) >= self.single_train for i in [1, 2]): # 模型尚未训练
|
||||||
self.zmqServer.broadcast_message('paradigm', 2) # 模型训练前,训练集采集完毕,通知上位机
|
self.zmqServer.broadcast_message('paradigm', 2) # 模型训练前,训练集采集完毕,通知上位机
|
||||||
self.train_started = True
|
self.train_started = True
|
||||||
self.trainData = np.array(self.trainData)
|
self.trainData = np.array(self.trainData)
|
||||||
self.trainLabel = np.array(self.trainLabel) + 1
|
self.trainLabel = np.array(self.trainLabel)
|
||||||
algo_log(f"MI开始训练,训练集:{np.shape(self.trainData)},标签shape:{np.shape(self.trainLabel)}", level="DEBUG")
|
algo_log(f"MI开始训练,训练集:{np.shape(self.trainData)},标签shape:{np.shape(self.trainLabel)}", level="DEBUG")
|
||||||
if save_train_data == 1:
|
if save_train_data == 1:
|
||||||
now_str = datetime.now().strftime("%Y%m%d_%H%M%S")
|
now_str = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
@@ -369,14 +376,15 @@ class Decoder_main(threading.Thread):
|
|||||||
'''训练阶段采集数据'''
|
'''训练阶段采集数据'''
|
||||||
if self.zmqServer.state_mode == 'train' and self.train_started == False: # 训练状态
|
if self.zmqServer.state_mode == 'train' and self.train_started == False: # 训练状态
|
||||||
if self.zmqServer.epoch_finished and self.zmqServer.paradigmBuffer.GetDataLenCount() >= \
|
if self.zmqServer.epoch_finished and self.zmqServer.paradigmBuffer.GetDataLenCount() >= \
|
||||||
self.interval_epoch[1] + self.zmqServer.event_inner_idx:
|
self.zmqServer.train_epoch[1] + self.zmqServer.event_inner_idx:
|
||||||
|
self.currentLabel = self.zmqServer.currentLabel # 同步当前标签
|
||||||
algo_log(f"训练队列数据:{self.zmqServer.paradigmBuffer.GetDataLenCount()}", level="DEBUG")
|
algo_log(f"训练队列数据:{self.zmqServer.paradigmBuffer.GetDataLenCount()}", level="DEBUG")
|
||||||
originalTrial = self.zmqServer.paradigmBuffer.get_MIData() # 取出MI导联数据
|
originalTrial = self.zmqServer.paradigmBuffer.get_MIData() # 取出MI导联数据
|
||||||
algo_log(f"取出的:{originalTrial.shape},event: {originalTrial[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
|
algo_log(f"取出的:{originalTrial.shape},event: {originalTrial[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
|
||||||
trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理
|
trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理
|
||||||
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.interval_epoch[
|
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.interval_epoch[
|
||||||
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]
|
0]:self.zmqServer.event_inner_idx + self.interval_epoch[1]]
|
||||||
algo_log(f"trial: {self.zmqServer.event_inner_idx},{self.interval_epoch[0]},{self.interval_epoch[1]}", level="DEBUG")
|
# algo_log(f"trial: {self.zmqServer.event_inner_idx},{self.interval_epoch[0]},{self.interval_epoch[1]}", level="DEBUG")
|
||||||
if trainTrial.shape[1] == (self.interval_epoch[1] - self.interval_epoch[0]) and isinstance(self.trainLabel,
|
if trainTrial.shape[1] == (self.interval_epoch[1] - self.interval_epoch[0]) and isinstance(self.trainLabel,
|
||||||
list) \
|
list) \
|
||||||
and self.trainLabel.count(self.currentLabel) < self.single_train:
|
and self.trainLabel.count(self.currentLabel) < self.single_train:
|
||||||
@@ -421,9 +429,9 @@ class Decoder_main(threading.Thread):
|
|||||||
y_pred = torch.max(Cls, 1)[1]
|
y_pred = torch.max(Cls, 1)[1]
|
||||||
self.plotLabel.append(int(y_pred.item()))
|
self.plotLabel.append(int(y_pred.item()))
|
||||||
algo_log(f"MI运动意图识别: {y_pred}")
|
algo_log(f"MI运动意图识别: {y_pred}")
|
||||||
self.zmqServer.broadcast_message('paradigm', int(y_pred.item()))
|
self.zmqServer.broadcast_message('result', int(y_pred.item()))
|
||||||
end = time.time()
|
end = time.time()
|
||||||
print(f'发送给界面完成,耗时{end - start:.3f}s。')
|
algo_log(f'MI发送给界面完成,耗时{end - start:.3f}s。')
|
||||||
else: # 休息状态
|
else: # 休息状态
|
||||||
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
|
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
|
||||||
time.sleep(0.005)
|
time.sleep(0.005)
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ cudnn.benchmark = True
|
|||||||
cudnn.deterministic = True
|
cudnn.deterministic = True
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
# writer = SummaryWriter('./TensorBoardX/')
|
# writer = SummaryWriter('./TensorBoardX/')
|
||||||
|
from logs.log import algo_log
|
||||||
|
|
||||||
# Convolution module
|
# Convolution module
|
||||||
# use conv to capture local features, instead of postion embedding.
|
# use conv to capture local features, instead of postion embedding.
|
||||||
@@ -318,11 +318,7 @@ class ExP():
|
|||||||
train_pred = torch.max(outputs, 1)[1]
|
train_pred = torch.max(outputs, 1)[1]
|
||||||
train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))
|
train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))
|
||||||
|
|
||||||
print('Epoch:', e,
|
algo_log(f"Epoch = {e}, Train loss = {loss.detach().cpu().numpy():.6f}, Test loss = {loss_test.detach().cpu().numpy():.6f}, Train accuracy = {train_acc:.6f}, Test accuracy = {acc:.6f}", level="debug")
|
||||||
' Train loss: %.6f' % loss.detach().cpu().numpy(),
|
|
||||||
' Test loss: %.6f' % loss_test.detach().cpu().numpy(),
|
|
||||||
' Train accuracy %.6f' % train_acc,
|
|
||||||
' Test accuracy is %.6f' % acc)
|
|
||||||
|
|
||||||
self.log_write.write(str(e) + " " + str(acc) + "\n")
|
self.log_write.write(str(e) + " " + str(acc) + "\n")
|
||||||
num = num + 1
|
num = num + 1
|
||||||
@@ -335,8 +331,8 @@ class ExP():
|
|||||||
|
|
||||||
torch.save(self.model, model_path)
|
torch.save(self.model, model_path)
|
||||||
averAcc = averAcc / num
|
averAcc = averAcc / num
|
||||||
print('The average accuracy is:', averAcc)
|
algo_log(f"The average accuracy is: {averAcc}", level="debug")
|
||||||
print('The best accuracy is:', bestAcc)
|
algo_log(f"The best accuracy is: {bestAcc}", level="debug")
|
||||||
self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n")
|
self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n")
|
||||||
self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n")
|
self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n")
|
||||||
|
|
||||||
@@ -346,10 +342,10 @@ class ExP():
|
|||||||
|
|
||||||
def onlineTrain(data_queue,result_queue):
|
def onlineTrain(data_queue,result_queue):
|
||||||
import torch
|
import torch
|
||||||
print(f"[DEBUG] torch.__version__ = {torch.__version__}")
|
algo_log(f"[DEBUG] torch.__version__ = {torch.__version__}", level="debug")
|
||||||
print(f"[DEBUG] torch.cuda.is_available() = {torch.cuda.is_available()}")
|
algo_log(f"[DEBUG] torch.cuda.is_available() = {torch.cuda.is_available()}", level="debug")
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
print(f"[DEBUG] GPU = {torch.cuda.get_device_name(0)}")
|
algo_log(f"[DEBUG] GPU = {torch.cuda.get_device_name(0)}", level="debug")
|
||||||
try:
|
try:
|
||||||
starttime = datetime.datetime.now()
|
starttime = datetime.datetime.now()
|
||||||
|
|
||||||
@@ -366,12 +362,13 @@ def onlineTrain(data_queue,result_queue):
|
|||||||
data = data_queue.get(timeout=30)
|
data = data_queue.get(timeout=30)
|
||||||
all_data, all_label,model_path,n_chan = data['data'], data['label'],data['modelPath'],data['n_chan']
|
all_data, all_label,model_path,n_chan = data['data'], data['label'],data['modelPath'],data['n_chan']
|
||||||
exp = ExP(n_chan)
|
exp = ExP(n_chan)
|
||||||
print('训练参数: ',np.shape(all_data),np.shape(all_label),model_path)
|
algo_log(f"训练参数: {np.shape(all_data)}, {np.shape(all_label)}, {model_path}", level="debug")
|
||||||
bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,model_path)
|
bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,model_path)
|
||||||
print('THE BEST ACCURACY IS ' + str(bestAcc))
|
algo_log(f"THE BEST ACCURACY IS {str(bestAcc)}", level="debug")
|
||||||
|
|
||||||
endtime = datetime.datetime.now()
|
endtime = datetime.datetime.now()
|
||||||
print('train duration: ',str(endtime - starttime))
|
algo_log(f"train duration: {endtime - starttime}", level="debug")
|
||||||
|
|
||||||
|
|
||||||
# 将模型或参数传回
|
# 将模型或参数传回
|
||||||
result_queue.put({
|
result_queue.put({
|
||||||
@@ -387,7 +384,7 @@ def offlineTrain(all_data,all_label,modelPath):
|
|||||||
|
|
||||||
# seed_n = np.random.randint(2025)
|
# seed_n = np.random.randint(2025)
|
||||||
seed_n = 1877
|
seed_n = 1877
|
||||||
print('seed is ' + str(seed_n))
|
algo_log(f"seed is {seed_n}", level="debug")
|
||||||
random.seed(seed_n)
|
random.seed(seed_n)
|
||||||
np.random.seed(seed_n)
|
np.random.seed(seed_n)
|
||||||
torch.manual_seed(seed_n)
|
torch.manual_seed(seed_n)
|
||||||
@@ -397,13 +394,12 @@ def offlineTrain(all_data,all_label,modelPath):
|
|||||||
exp = ExP()
|
exp = ExP()
|
||||||
|
|
||||||
bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,modelPath)
|
bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,modelPath)
|
||||||
print('THE BEST ACCURACY IS ' + str(bestAcc))
|
algo_log('THE BEST ACCURACY IS ' + str(bestAcc), level="debug")
|
||||||
|
|
||||||
endtime = datetime.datetime.now()
|
endtime = datetime.datetime.now()
|
||||||
print('train duration: ',str(endtime - starttime))
|
algo_log(f"train duration: {endtime - starttime}", level="debug")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
print(time.asctime(time.localtime(time.time())))
|
algo_log(f"[DEBUG] time.asctime(time.localtime(time.time())) = {time.asctime(time.localtime(time.time()))}", level="debug")
|
||||||
print(time.asctime(time.localtime(time.time())))
|
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from einops import rearrange
|
|||||||
from einops.layers.torch import Rearrange, Reduce
|
from einops.layers.torch import Rearrange, Reduce
|
||||||
from torch.backends import cudnn
|
from torch.backends import cudnn
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
|
from logs.log import algo_log
|
||||||
# writer = SummaryWriter('./TensorBoardX/')
|
# writer = SummaryWriter('./TensorBoardX/')
|
||||||
|
|
||||||
|
|
||||||
@@ -190,7 +191,7 @@ class ExP():
|
|||||||
# 自动选择设备:有 GPU 用 GPU,否则用 CPU
|
# 自动选择设备:有 GPU 用 GPU,否则用 CPU
|
||||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
# self.device = torch.device("cpu")
|
# self.device = torch.device("cpu")
|
||||||
print(f"Using device: {self.device}")
|
algo_log(f"Using device: {self.device}", level="debug")
|
||||||
|
|
||||||
# 定义张量类型(不再强制使用 cuda)
|
# 定义张量类型(不再强制使用 cuda)
|
||||||
self.Tensor = torch.FloatTensor
|
self.Tensor = torch.FloatTensor
|
||||||
|
|||||||
21
README.md
21
README.md
@@ -13,14 +13,31 @@ Debug_64ch_Decoder_Optimize is an updated version that fixes several issues and
|
|||||||
6. decoder class切换问题
|
6. decoder class切换问题
|
||||||
7. decoder_class切换时,数据重置、各类参数重置
|
7. decoder_class切换时,数据重置、各类参数重置
|
||||||
|
|
||||||
|
# realease log
|
||||||
|
- 2026年6月11日11:29:17 打包第一版,包名runDecoder.dist_v0.0.0_beta_20260611.7z
|
||||||
|
- 2026年6月11日12:00:00 打包第二版,包名runDecoder.dist_v0.0.0_beta_20260611.7z
|
||||||
|
- 修复上位机先发decoder_class, 后发open_impedence 带来decoder_main thread 阻塞问题
|
||||||
|
|
||||||
|
- 2026年6月12日15:05:47 runDecoder.dist_v0.0.2_beta_20260612
|
||||||
|
- 优化filter读数精度
|
||||||
|
|
||||||
# 常用命令
|
# 常用命令
|
||||||
source activate 3in1Py310
|
source activate 3in1Py310
|
||||||
python runDecoder.py
|
python runDecoder.py
|
||||||
python datamock.py
|
python datamock.py
|
||||||
python ZeroMQClient_mock.py
|
python ZeroMQClient_mock.py
|
||||||
python system_test.py
|
python filter_test.py
|
||||||
|
python upperHost_stimmock/MI_headless.py
|
||||||
|
|
||||||
|
# 打包命令
|
||||||
|
./nuitka_3in1_package.sh
|
||||||
|
|
||||||
# 遗留问题
|
# TODO
|
||||||
1. mvep是否要把list freq 开放到config
|
1. mvep是否要把list freq 开放到config
|
||||||
|
2. 滤波器参数 放到config文件
|
||||||
|
|
||||||
|
# debug log
|
||||||
|
## MI
|
||||||
|
Epoch采集完成|收到命令: {'method': 'train'|取出的
|
||||||
|
|
||||||
|
收到命令: {'method': 'train'|收到命令: {'method': 'train'|收到命令: {'method': 'predict'|事件检测到
|
||||||
@@ -12,16 +12,17 @@ from scipy.io import loadmat
|
|||||||
from scipy.linalg import qr
|
from scipy.linalg import qr
|
||||||
from scipy.signal import filtfilt, lfilter
|
from scipy.signal import filtfilt, lfilter
|
||||||
# from numpy.linalg import _umath_linalg
|
# from numpy.linalg import _umath_linalg
|
||||||
|
from logs.log import algo_log
|
||||||
|
|
||||||
|
|
||||||
class FbccaDw:
|
class FbccaDw:
|
||||||
def __init__(self, fs, num_target, num_chans, num_filter, num_harms, stimTime, parameter, width, winNum,method):
|
def __init__(self, fs, num_target, num_chans, num_filter, num_harms, stimTime, parameter, width, winNum,method):
|
||||||
print('******************************************')
|
algo_log('******************************************', level="debug")
|
||||||
print('parameter list')
|
algo_log('parameter list',level="debug")
|
||||||
print('target:', num_target)
|
algo_log(f"target: {num_target}", level="debug")
|
||||||
print('number of filter bank:', num_filter)
|
algo_log(f"number of filter bank: {num_filter}", level="debug")
|
||||||
print('parameter:', parameter)
|
algo_log(f"parameter: {parameter}", level="debug")
|
||||||
print('width:', width)
|
algo_log(f"width: {width}", level="debug")
|
||||||
self.phase = 0
|
self.phase = 0
|
||||||
self.bandWidth = width
|
self.bandWidth = width
|
||||||
self.winNum = winNum
|
self.winNum = winNum
|
||||||
@@ -237,7 +238,7 @@ class FbccaDw:
|
|||||||
dataFiltered, self.notchZh[0] = lfilter(self.north_b, self.north_a, data, zi=self.notchZh[0])
|
dataFiltered, self.notchZh[0] = lfilter(self.north_b, self.north_a, data, zi=self.notchZh[0])
|
||||||
return np.asmatrix(dataFiltered)
|
return np.asmatrix(dataFiltered)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(Exception)
|
algo_log(f"Exception: {Exception}", level="debug")
|
||||||
|
|
||||||
'''
|
'''
|
||||||
getDataQ
|
getDataQ
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ class Beta_Calculate():
|
|||||||
alpha_psd = np.sum(self.band_psd(freqs, psd, (8, 13)))
|
alpha_psd = np.sum(self.band_psd(freqs, psd, (8, 13)))
|
||||||
theta_psd = np.sum(self.band_psd(freqs, psd, (4, 8)))
|
theta_psd = np.sum(self.band_psd(freqs, psd, (4, 8)))
|
||||||
|
|
||||||
print(f"[功率] β={beta_psd:.2f} | α={alpha_psd:.2f} | θ={theta_psd:.2f}")
|
# print(f"[功率] β={beta_psd:.2f} | α={alpha_psd:.2f} | θ={theta_psd:.2f}")
|
||||||
|
|
||||||
return beta_psd, alpha_psd, theta_psd
|
return beta_psd, alpha_psd, theta_psd
|
||||||
|
|
||||||
|
|||||||
@@ -89,7 +89,8 @@ def zero_mq_client(server_address="tcp://127.0.0.1:8099"):
|
|||||||
{"method": "train", "params": 1},
|
{"method": "train", "params": 1},
|
||||||
{"method": "rest", "params": 0},
|
{"method": "rest", "params": 0},
|
||||||
{"method": "predict", "params": 1},
|
{"method": "predict", "params": 1},
|
||||||
{"method": "getReport", "params": 0}
|
{"method": "getReport", "params": 0},
|
||||||
|
{"method": "targetFreqs", "params": [11, 12, 13]}
|
||||||
]
|
]
|
||||||
|
|
||||||
# 打印消息集
|
# 打印消息集
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ class ParadigmRingBuffer:
|
|||||||
def appendBuffer(self, data):
|
def appendBuffer(self, data):
|
||||||
if self.nUpdate == self.n_points:
|
if self.nUpdate == self.n_points:
|
||||||
# raise Exception("Buffer is full")
|
# raise Exception("Buffer is full")
|
||||||
algo_log("Buffer is full", record_once=True)
|
algo_log("ParadigmRingBuffer is full", record_once=True)
|
||||||
|
|
||||||
n = data.shape[1]
|
n = data.shape[1]
|
||||||
|
|
||||||
|
|||||||
@@ -5,8 +5,13 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import time
|
import time
|
||||||
import threading
|
import threading
|
||||||
|
import queue
|
||||||
from scipy import signal
|
from scipy import signal
|
||||||
from logs.log import algo_log
|
from logs.log import algo_log
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
from Tools.beta_calculate import Beta_Calculate
|
||||||
|
|
||||||
class FilterRingBuffer:
|
class FilterRingBuffer:
|
||||||
def __init__(self, n_chan, n_points):
|
def __init__(self, n_chan, n_points):
|
||||||
@@ -89,7 +94,74 @@ class FilterRingBuffer:
|
|||||||
self.has_new_data = False # 重置时清空新数据标记
|
self.has_new_data = False # 重置时清空新数据标记
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# 2. 独立滑动滤波类(仅负责滤波业务逻辑,不关心缓存实现)
|
# 2. 独立 Beta PSD 计算线程(避免阻塞滤波主循环的 200ms 定时)
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
class BetaPsdCalculator(threading.Thread):
|
||||||
|
"""独立的 Beta PSD 计算线程,使用队列与滤波主线程解耦"""
|
||||||
|
|
||||||
|
def __init__(self, fs=250, window_size=750):
|
||||||
|
super().__init__(daemon=True)
|
||||||
|
self.fs = fs
|
||||||
|
self.window_size = window_size
|
||||||
|
self._beta_calc = Beta_Calculate(Threshold_value_low=0, Threshold_value_high=0, fs=fs)
|
||||||
|
self._input_queue = queue.Queue(maxsize=2)
|
||||||
|
self._running = threading.Event()
|
||||||
|
self._running.set()
|
||||||
|
self._latest_beta = None
|
||||||
|
self._beta_lock = threading.Lock()
|
||||||
|
self.beta_broadcast_callback = None
|
||||||
|
|
||||||
|
def push_data(self, data):
|
||||||
|
"""供外部调用的线程安全数据推送接口"""
|
||||||
|
try:
|
||||||
|
self._input_queue.put_nowait(data)
|
||||||
|
except queue.Full:
|
||||||
|
try:
|
||||||
|
self._input_queue.get_nowait()
|
||||||
|
except queue.Empty:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
self._input_queue.put_nowait(data)
|
||||||
|
except queue.Full:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_latest_beta(self):
|
||||||
|
"""获取最新的 beta 值(线程安全)"""
|
||||||
|
with self._beta_lock:
|
||||||
|
return self._latest_beta
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
while self._running.is_set():
|
||||||
|
try:
|
||||||
|
data = self._input_queue.get(timeout=1.5)
|
||||||
|
if data is None:
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
beta_psd, _, _ = self._beta_calc.calculate_all(
|
||||||
|
data, fs=self.fs, nperseg=min(self.window_size, data.shape[1])
|
||||||
|
)
|
||||||
|
with self._beta_lock:
|
||||||
|
self._latest_beta = round(float(beta_psd), 3)
|
||||||
|
if self.beta_broadcast_callback is not None:
|
||||||
|
self.beta_broadcast_callback(self._latest_beta)
|
||||||
|
except Exception as e:
|
||||||
|
algo_log(f"Beta PSD 计算异常: {e}", level='error')
|
||||||
|
except queue.Empty:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
"""停止计算线程"""
|
||||||
|
self._running.clear()
|
||||||
|
try:
|
||||||
|
self._input_queue.put_nowait(None)
|
||||||
|
except queue.Full:
|
||||||
|
pass
|
||||||
|
if self.is_alive():
|
||||||
|
self.join(timeout=2)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# 3. 独立滑动滤波类(仅负责滤波业务逻辑,不关心缓存实现)
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
class SlidingFilter(threading.Thread):
|
class SlidingFilter(threading.Thread):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -118,24 +190,42 @@ class SlidingFilter(threading.Thread):
|
|||||||
# 滤波结果回调(外部可注册,获取滤波后的数据)
|
# 滤波结果回调(外部可注册,获取滤波后的数据)
|
||||||
self.filter_result_callback = None
|
self.filter_result_callback = None
|
||||||
|
|
||||||
|
# beta 每秒触发计数(200ms步长,5次 = 1s)
|
||||||
|
self._beta_step_counter = 0
|
||||||
|
self._beta_steps_per_second = max(1, int(round(1.0 / step_sec))) # 5
|
||||||
|
|
||||||
|
self.slide_window = None # 滑动窗口缓存 (n_chan, window_size)
|
||||||
|
self.slide_ready = False # 窗口是否已填满初始数据
|
||||||
# 预计算滤波器系数(仅执行一次)
|
# 预计算滤波器系数(仅执行一次)
|
||||||
self._init_filters()
|
self._init_filters()
|
||||||
|
|
||||||
|
# 独立的 Beta 计算线程(避免阻塞滤波主循环)
|
||||||
|
self._beta_thread = BetaPsdCalculator(fs=srate, window_size=self.window_size)
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
"""同时启动 Beta 计算线程和滤波主线程"""
|
||||||
|
self._beta_thread.start()
|
||||||
|
super().start()
|
||||||
|
|
||||||
|
def set_beta_broadcast_callback(self, callback):
|
||||||
|
"""注册 Beta PSD 广播回调函数"""
|
||||||
|
self._beta_thread.beta_broadcast_callback = callback
|
||||||
|
|
||||||
def _init_filters(self):
|
def _init_filters(self):
|
||||||
"""预计算所有滤波器系数(仅执行一次)"""
|
"""预计算所有滤波器系数(仅执行一次)"""
|
||||||
# 50Hz工频陷波(Q=30,工业标准)
|
# 50Hz工频陷波(Q=30,工业标准)
|
||||||
self.b_notch, self.a_notch = signal.iirnotch(50, 30, self.srate)
|
self.b_notch, self.a_notch = signal.iirnotch(50, 30, self.srate)
|
||||||
# 8~30Hz带通FIR(65阶,线性相位)
|
# 0.5~45Hz带通FIR(65阶,线性相位)
|
||||||
self.b_bp = signal.firwin(
|
self.b_bp = signal.firwin(
|
||||||
numtaps=65,
|
numtaps=65,
|
||||||
cutoff=[8/(self.srate/2), 30/(self.srate/2)],
|
cutoff=[0.5/(self.srate/2), 45/(self.srate/2)],
|
||||||
pass_zero=False,
|
pass_zero=False,
|
||||||
window='hamming'
|
window='hamming'
|
||||||
)
|
)
|
||||||
self.a_bp = np.array([1.0])
|
self.a_bp = np.array([1.0])
|
||||||
|
|
||||||
def _filter_window_data(self, window_data):
|
def _filter_window_data(self, window_data):
|
||||||
"""对3秒窗口数据执行滤波,返回无边界效应的200ms数据"""
|
"""对3秒窗口数据执行滤波,返回 (无边界效应的200ms数据, 完整3s滤波数据)"""
|
||||||
# 零相位滤波(无延迟,无边界效应)
|
# 零相位滤波(无延迟,无边界效应)
|
||||||
filtered = window_data - np.mean(window_data, axis=-1, keepdims=True)
|
filtered = window_data - np.mean(window_data, axis=-1, keepdims=True)
|
||||||
filtered = signal.filtfilt(self.b_notch, self.a_notch, filtered, axis=-1)
|
filtered = signal.filtfilt(self.b_notch, self.a_notch, filtered, axis=-1)
|
||||||
@@ -146,39 +236,64 @@ class SlidingFilter(threading.Thread):
|
|||||||
start_idx = self.window_size - 2 * self.step_size
|
start_idx = self.window_size - 2 * self.step_size
|
||||||
end_idx = self.window_size - self.step_size
|
end_idx = self.window_size - self.step_size
|
||||||
output_data = filtered[:, start_idx:end_idx].copy()
|
output_data = filtered[:, start_idx:end_idx].copy()
|
||||||
return output_data
|
return output_data, filtered
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
"""线程主逻辑:精确200ms触发一次滤波"""
|
"""线程主逻辑:精确200ms触发一次滤波"""
|
||||||
interval = self.step_sec # 200ms = 0.2秒
|
interval = self.step_sec # 0.2s
|
||||||
next_run_time = time.perf_counter()
|
# 以启动时刻为绝对时间基准(核心改动)
|
||||||
while self.running.is_set():
|
base_time = time.perf_counter()
|
||||||
# 1. 精确定时等待
|
frame_count = 0 # 帧计数器,用于对齐时序
|
||||||
current_time = time.perf_counter()
|
|
||||||
if current_time < next_run_time:
|
|
||||||
time.sleep(next_run_time - current_time)
|
|
||||||
next_run_time += interval
|
|
||||||
else:
|
|
||||||
algo_log("滤波耗时超过200ms,定时偏移", level='debug')
|
|
||||||
next_run_time = time.perf_counter() + interval
|
|
||||||
|
|
||||||
# ========== 新增核心判断:无新数据则直接跳过 ==========
|
while self.running.is_set():
|
||||||
|
# 计算理论执行时刻:严格按帧序号 × 步长
|
||||||
|
expect_time = base_time + frame_count * interval
|
||||||
|
current_time = time.perf_counter()
|
||||||
|
|
||||||
|
# 精确定时等待
|
||||||
|
if current_time < expect_time:
|
||||||
|
time.sleep(expect_time - current_time)
|
||||||
|
else:
|
||||||
|
# 处理超时:仅告警,不重置基准(防止累积偏移)
|
||||||
|
algo_log(f"滤波任务超时,偏移 {(current_time - expect_time)*1000:.1f} ms", level='debug')
|
||||||
|
|
||||||
|
frame_count += 1 # 帧序号自增,保证周期绝对稳定
|
||||||
if not self.ring_buffer.check_and_clear_new_data():
|
if not self.ring_buffer.check_and_clear_new_data():
|
||||||
# 无新数据,不执行滤波、不发送数据
|
# 无新数据,不执行滤波、不发送数据
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 2. 有新数据,才执行原有滤波逻辑
|
# ========== 原有滤波逻辑 ==========
|
||||||
try:
|
try:
|
||||||
window_data = self.ring_buffer.get_latest_n_points(self.window_size)
|
if not self.slide_ready:
|
||||||
if window_data is None:
|
# 阶段1:首次填满3s初始窗口
|
||||||
algo_log(f"缓存数据不足,当前缓存{self.ring_buffer.GetDataLenCount()}点,需{self.window_size}点", level='debug')
|
full_data = self.ring_buffer.get_latest_n_points(self.window_size)
|
||||||
|
if full_data is None:
|
||||||
|
algo_log("初始窗口数据不足", level='debug')
|
||||||
continue
|
continue
|
||||||
|
self.slide_window = full_data
|
||||||
|
self.slide_ready = True
|
||||||
|
else:
|
||||||
|
# 阶段2:正常滑动 → 取最新50个新点,增量拼接
|
||||||
|
new_step_data = self.ring_buffer.get_latest_n_points(self.step_size)
|
||||||
|
if new_step_data is None:
|
||||||
|
algo_log("滑动步长数据不足", level='debug')
|
||||||
|
continue
|
||||||
|
# 增量滑动:丢弃前50点,拼接新50点(标准滑动窗口)
|
||||||
|
self.slide_window = np.hstack([
|
||||||
|
self.slide_window[:, self.step_size:],
|
||||||
|
new_step_data
|
||||||
|
])
|
||||||
|
|
||||||
filtered_data = self._filter_window_data(window_data)
|
filtered_data, filtered_full = self._filter_window_data(self.slide_window[:64, :])
|
||||||
# algo_log(f"滤波后{filtered_data.shape}数据", level='debug')
|
|
||||||
|
# Beta PSD 每秒计算一次
|
||||||
|
self._beta_step_counter += 1
|
||||||
|
if self._beta_step_counter >= self._beta_steps_per_second:
|
||||||
|
self._beta_step_counter = 0
|
||||||
|
self._beta_thread.push_data(filtered_full[:2, :])
|
||||||
|
|
||||||
if self.filter_result_callback is not None:
|
if self.filter_result_callback is not None:
|
||||||
self.filter_result_callback(filtered_data[:64, :])
|
self.filter_result_callback(filtered_data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
algo_log(f"滤波执行异常: {e}", level='error')
|
algo_log(f"滤波执行异常: {e}", level='error')
|
||||||
|
|
||||||
@@ -187,17 +302,11 @@ class SlidingFilter(threading.Thread):
|
|||||||
self.filter_result_callback = callback
|
self.filter_result_callback = callback
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
"""停止滤波线程(安全版)"""
|
"""停止滤波线程和 Beta 计算线程"""
|
||||||
# 1. 先设置停止标志(Event.clear()是线程安全的)
|
self._beta_thread.stop()
|
||||||
self.running.clear()
|
self.running.clear()
|
||||||
|
|
||||||
# 2. 核心修复:只有线程已启动且正在运行时才调用join
|
|
||||||
if self.is_alive():
|
if self.is_alive():
|
||||||
# 等待线程正常退出,最多1秒
|
|
||||||
self.join(timeout=1)
|
self.join(timeout=1)
|
||||||
# 超时未退出时打印警告,便于排查问题
|
|
||||||
if self.is_alive():
|
if self.is_alive():
|
||||||
algo_log("警告:滤波线程在1秒内未正常退出,可能存在阻塞操作", level="WARNING")
|
algo_log("警告:滤波线程在1秒内未正常退出,可能存在阻塞操作", level="WARNING")
|
||||||
|
|
||||||
# 3. 无论线程是否启动,都打印停止日志
|
|
||||||
algo_log("滤波线程已停止")
|
algo_log("滤波线程已停止")
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
import ast
|
import ast
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import threading
|
import threading
|
||||||
|
import zmq
|
||||||
import json
|
import json
|
||||||
import queue
|
import queue
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
@@ -13,17 +14,14 @@ from Zmq.filterProcess import FilterRingBuffer
|
|||||||
from PubLibrary.InifileHelper import IniRead
|
from PubLibrary.InifileHelper import IniRead
|
||||||
from logs.log import algo_log
|
from logs.log import algo_log
|
||||||
|
|
||||||
import zmq
|
zmqServer_host = str(IniRead('system', 'zmqServer_host', '127.0.0.1'))
|
||||||
|
|
||||||
class zmqServer(threading.Thread):
|
class zmqServer(threading.Thread):
|
||||||
def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100, device_info=None):
|
def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100, device_info=None):
|
||||||
threading.Thread.__init__(self)
|
threading.Thread.__init__(self)
|
||||||
self.device_info = device_info
|
self.device_info = device_info
|
||||||
|
|
||||||
self.host = host
|
self.host = zmqServer_host
|
||||||
|
|
||||||
# test_host = "10.200.27.140"
|
|
||||||
# self.host = test_host
|
|
||||||
|
|
||||||
self.cmd_port = cmd_port # 命令交互端口:收JSON命令 + 返JSON结果
|
self.cmd_port = cmd_port # 命令交互端口:收JSON命令 + 返JSON结果
|
||||||
self.data_port = data_port # 数据交互端口:收二进制原始脑电 + 返二进制滤波结果
|
self.data_port = data_port # 数据交互端口:收二进制原始脑电 + 返二进制滤波结果
|
||||||
@@ -96,6 +94,7 @@ class zmqServer(threading.Thread):
|
|||||||
self.pack_contain_event = False
|
self.pack_contain_event = False
|
||||||
self.event_inner_idx = -1
|
self.event_inner_idx = -1
|
||||||
self.interval_inited = False
|
self.interval_inited = False
|
||||||
|
self.last_epoch_finish_time = None
|
||||||
|
|
||||||
def reset_state(self):
|
def reset_state(self):
|
||||||
"""清空采集器状态和缓存数据"""
|
"""清空采集器状态和缓存数据"""
|
||||||
@@ -119,11 +118,11 @@ class zmqServer(threading.Thread):
|
|||||||
self.train_latency = (self.train_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5 #120包 600个点
|
self.train_latency = (self.train_epoch[1] + 0.1 * self.device_info['sample_rate']) // 5 #120包 600个点
|
||||||
|
|
||||||
elif decoder_class == 'mi':
|
elif decoder_class == 'mi':
|
||||||
interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
|
interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch')) # [0.5, 4.5]
|
||||||
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch]
|
self.interval_epoch = [int(i * self.device_info['sample_rate']) for i in interval_epoch] #[125, 1125]
|
||||||
self.train_epoch = self.interval_epoch.copy()
|
self.train_epoch = self.interval_epoch.copy()
|
||||||
self.latency = self.interval_epoch[1] // 5
|
self.latency = self.interval_epoch[1] // 5 #225
|
||||||
self.train_latency = self.latency
|
self.train_latency = self.latency #225
|
||||||
|
|
||||||
algo_log(f"时间窗初始化完成: {interval_epoch}", level="INFO")
|
algo_log(f"时间窗初始化完成: {interval_epoch}", level="INFO")
|
||||||
self.count_events: Dict[str, int] = {}
|
self.count_events: Dict[str, int] = {}
|
||||||
@@ -153,6 +152,7 @@ class zmqServer(threading.Thread):
|
|||||||
msg = {'method': method, 'params': params}
|
msg = {'method': method, 'params': params}
|
||||||
msg_bytes = json.dumps(msg).encode('utf-8')
|
msg_bytes = json.dumps(msg).encode('utf-8')
|
||||||
|
|
||||||
|
if msg['method'] != 'beta_psd':
|
||||||
algo_log(f"发送命令结果: {msg}", level="DEBUG")
|
algo_log(f"发送命令结果: {msg}", level="DEBUG")
|
||||||
|
|
||||||
# 广播到所有命令客户端
|
# 广播到所有命令客户端
|
||||||
@@ -180,7 +180,7 @@ class zmqServer(threading.Thread):
|
|||||||
# 转置为上位机需要的[50, 通道数]格式
|
# 转置为上位机需要的[50, 通道数]格式
|
||||||
filtered_data = filtered_data.T.astype(np.float64)
|
filtered_data = filtered_data.T.astype(np.float64)
|
||||||
send_buf = filtered_data.tobytes()
|
send_buf = filtered_data.tobytes()
|
||||||
algo_log(f"发送滤波数据,长度: {len(send_buf)}字节, filtered_data.shape: {filtered_data.shape}", level="DEBUG", record_once=True)
|
# algo_log(f"发送滤波数据,长度: {len(send_buf)}字节, filtered_data.shape: {filtered_data.shape}", level="DEBUG", record_once=True)
|
||||||
self.data_send_queue.put(send_buf)
|
self.data_send_queue.put(send_buf)
|
||||||
|
|
||||||
def _process_data_send_queue(self):
|
def _process_data_send_queue(self):
|
||||||
@@ -197,6 +197,7 @@ class zmqServer(threading.Thread):
|
|||||||
b"",
|
b"",
|
||||||
send_buf
|
send_buf
|
||||||
])
|
])
|
||||||
|
algo_log(f"发送滤波数据成功,长度: {len(send_buf)}字节", level="DEBUG", record_once=True)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
algo_log(f"发送滤波数据失败: {e}", level="ERROR")
|
algo_log(f"发送滤波数据失败: {e}", level="ERROR")
|
||||||
@@ -225,6 +226,9 @@ class zmqServer(threading.Thread):
|
|||||||
algo_log(f"无效JSON命令: {message_bytes.hex()}", level="ERROR")
|
algo_log(f"无效JSON命令: {message_bytes.hex()}", level="ERROR")
|
||||||
self.broadcast_message("error", {"code": 400, "message": "无效JSON格式"})
|
self.broadcast_message("error", {"code": 400, "message": "无效JSON格式"})
|
||||||
return
|
return
|
||||||
|
except Exception as e:
|
||||||
|
algo_log(f"_handle_cmd_message exception: {e}", level="ERROR")
|
||||||
|
return
|
||||||
|
|
||||||
algo_log(f"收到命令: {message}", level="INFO")
|
algo_log(f"收到命令: {message}", level="INFO")
|
||||||
method = message.get("method")
|
method = message.get("method")
|
||||||
@@ -270,6 +274,22 @@ class zmqServer(threading.Thread):
|
|||||||
elif params == 2: #停止解码
|
elif params == 2: #停止解码
|
||||||
self.IsExitApp = True
|
self.IsExitApp = True
|
||||||
self.running = False
|
self.running = False
|
||||||
|
|
||||||
|
resp = {
|
||||||
|
"method": "predict_response",
|
||||||
|
"params": {
|
||||||
|
"code": 200,
|
||||||
|
"message": "ok"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
resp_bytes = json.dumps(resp, ensure_ascii=False).encode("utf-8")
|
||||||
|
self.cmd_socket.send_multipart([ident, b"", resp_bytes])
|
||||||
|
algo_log(f"predict 命令已即时回复客户端 {ident}", level="DEBUG")
|
||||||
|
except Exception as e:
|
||||||
|
algo_log(f"predict 命令回复失败: {e}", level="ERROR")
|
||||||
|
return
|
||||||
|
|
||||||
elif method == "rest":
|
elif method == "rest":
|
||||||
self.state_mode = 'rest'
|
self.state_mode = 'rest'
|
||||||
elif method == "impedance":
|
elif method == "impedance":
|
||||||
@@ -291,6 +311,8 @@ class zmqServer(threading.Thread):
|
|||||||
elif len(frames) == 3:
|
elif len(frames) == 3:
|
||||||
# 标准格式
|
# 标准格式
|
||||||
ident, empty_sep, data_bytes = frames[:3]
|
ident, empty_sep, data_bytes = frames[:3]
|
||||||
|
elif len(frames) == 2:
|
||||||
|
ident, data_bytes = frames[:2]
|
||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
# 注册新的数据客户端(单客户端场景,自动覆盖旧身份)
|
# 注册新的数据客户端(单客户端场景,自动覆盖旧身份)
|
||||||
@@ -322,8 +344,22 @@ class zmqServer(threading.Thread):
|
|||||||
if self.pack_contain_event:
|
if self.pack_contain_event:
|
||||||
self.paradigmBuffer.resetAllPara()
|
self.paradigmBuffer.resetAllPara()
|
||||||
self.paradigmBuffer.appendBuffer(data_np)
|
self.paradigmBuffer.appendBuffer(data_np)
|
||||||
|
|
||||||
if self.epoch_finished:
|
if self.epoch_finished:
|
||||||
algo_log('Epoch采集完成: ' + datetime.datetime.now().strftime('%H:%M:%S.%f')[:-3], level="DEBUG")
|
now = datetime.datetime.now()
|
||||||
|
time_diff_str = ""
|
||||||
|
# 计算与上一次Epoch完成的时间差
|
||||||
|
if self.last_epoch_finish_time is not None:
|
||||||
|
# 时间差 单位:秒,保留3位小数
|
||||||
|
delta_seconds = (now - self.last_epoch_finish_time).total_seconds()
|
||||||
|
time_diff_str = f" | 与上一次间隔: {delta_seconds:.3f} s"
|
||||||
|
|
||||||
|
# 拼接日志,增加时间差信息
|
||||||
|
log_msg = f"Epoch采集完成: {now.strftime('%H:%M:%S.%f')[:-3]}{time_diff_str}"
|
||||||
|
algo_log(log_msg, level="DEBUG")
|
||||||
|
|
||||||
|
# 更新上一次Epoch完成时间为当前时间
|
||||||
|
self.last_epoch_finish_time = now
|
||||||
else:
|
else:
|
||||||
self.paradigmBuffer.appendBuffer(data_np)
|
self.paradigmBuffer.appendBuffer(data_np)
|
||||||
|
|
||||||
@@ -337,8 +373,8 @@ class zmqServer(threading.Thread):
|
|||||||
def detect_event(self, samples):
|
def detect_event(self, samples):
|
||||||
self.pack_contain_event = False
|
self.pack_contain_event = False
|
||||||
# 第65通道为事件通道
|
# 第65通道为事件通道
|
||||||
event = int(samples[-2][0])
|
events = np.array(samples[-2], dtype=np.int32).tolist()
|
||||||
# for idx, event in enumerate(events):
|
for idx, event in enumerate(events):
|
||||||
if event in self.events:
|
if event in self.events:
|
||||||
new_key = "".join(
|
new_key = "".join(
|
||||||
[
|
[
|
||||||
@@ -352,8 +388,8 @@ class zmqServer(threading.Thread):
|
|||||||
self.count_events[new_key] = self.latency + 1
|
self.count_events[new_key] = self.latency + 1
|
||||||
else:
|
else:
|
||||||
self.count_events[new_key] = self.train_latency + 1
|
self.count_events[new_key] = self.train_latency + 1
|
||||||
self.event_inner_idx = self.device_info['frame_points'] - 1
|
self.event_inner_idx = idx
|
||||||
# algo_log(f"事件检测到: {event},索引: {idx}", level="DEBUG")
|
algo_log(f"事件检测到: {events},索引: {idx}", level="DEBUG")
|
||||||
self.pack_contain_event = True
|
self.pack_contain_event = True
|
||||||
|
|
||||||
# 倒计时并清理过期事件
|
# 倒计时并清理过期事件
|
||||||
@@ -389,13 +425,18 @@ class zmqServer(threading.Thread):
|
|||||||
frames = self.cmd_socket.recv_multipart()
|
frames = self.cmd_socket.recv_multipart()
|
||||||
self._handle_cmd_message(frames)
|
self._handle_cmd_message(frames)
|
||||||
|
|
||||||
# 处理8100数据端口消息
|
# 处理8100数据端口消息(排空积压,消除标签延迟)
|
||||||
if self.data_socket in socks and socks[self.data_socket] == zmq.POLLIN:
|
if self.data_socket in socks and socks[self.data_socket] == zmq.POLLIN:
|
||||||
frames = self.data_socket.recv_multipart()
|
while True:
|
||||||
|
try:
|
||||||
|
frames = self.data_socket.recv_multipart(zmq.NOBLOCK)
|
||||||
self._handle_data_message(frames)
|
self._handle_data_message(frames)
|
||||||
|
except zmq.Again:
|
||||||
|
break
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
algo_log(f"服务器主循环异常: {e}", level="ERROR")
|
algo_log(f"服务器主循环异常: {str(e)}", level="ERROR")
|
||||||
|
return
|
||||||
finally:
|
finally:
|
||||||
self.running = False
|
self.running = False
|
||||||
# 优雅关闭所有资源
|
# 优雅关闭所有资源
|
||||||
|
|||||||
19
config.ini
19
config.ini
@@ -15,10 +15,25 @@ Audio_device = 0
|
|||||||
Rest_time = 2
|
Rest_time = 2
|
||||||
Upper_Host = 127.0.0.1
|
Upper_Host = 127.0.0.1
|
||||||
Upper_Port = 8088
|
Upper_Port = 8088
|
||||||
|
Decoder_Host = 127.0.0.1
|
||||||
|
Decoder_Port = 8099
|
||||||
Serial_port = COM44
|
Serial_port = COM44
|
||||||
algo_log_level = DEBUG
|
|
||||||
console_output = 1
|
|
||||||
save_train_data = 0
|
save_train_data = 0
|
||||||
|
zmqServer_host = 127.0.0.1
|
||||||
|
|
||||||
|
[algo_log]
|
||||||
|
# ========== 文件日志配置 ==========
|
||||||
|
file_log_enable = true
|
||||||
|
file_log_level = DEBUG
|
||||||
|
log_path = exe
|
||||||
|
retention_days = 3
|
||||||
|
|
||||||
|
# ========== 控制台/黑框配置 ==========
|
||||||
|
console_enable = true
|
||||||
|
console_show_window = true
|
||||||
|
console_log_level = DEBUG
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
; 64 导设备配置
|
; 64 导设备配置
|
||||||
[device_type_1]
|
[device_type_1]
|
||||||
|
|||||||
74
datamock.py
74
datamock.py
@@ -11,8 +11,8 @@ N_CHAN = 66 # 通道数: 64 EEG + 1 标签值 + 1 标签序号
|
|||||||
EEG_FREQ = 10 # EEG 正弦波频率 Hz
|
EEG_FREQ = 10 # EEG 正弦波频率 Hz
|
||||||
EEG_AMP = 100.0 # EEG 幅值 100μV
|
EEG_AMP = 100.0 # EEG 幅值 100μV
|
||||||
LABEL_INTERVAL = 5 # 标签间隔秒数
|
LABEL_INTERVAL = 5 # 标签间隔秒数
|
||||||
# SERVER_ADDR = 'tcp://127.0.0.1:8100'
|
SERVER_ADDR = 'tcp://127.0.0.1:8100'
|
||||||
SERVER_ADDR = 'tcp://10.200.27.140:8100'
|
LABEL_CMD_ADDR = 'tcp://127.0.0.1:8101' # 接收来自上位机范式的标签命令
|
||||||
|
|
||||||
# 发送间隔: 每包 5 采样点 / 250Hz = 20ms
|
# 发送间隔: 每包 5 采样点 / 250Hz = 20ms
|
||||||
PKT_INTERVAL = N_SAMPLES_PER_PKT / FS
|
PKT_INTERVAL = N_SAMPLES_PER_PKT / FS
|
||||||
@@ -67,9 +67,41 @@ def main():
|
|||||||
sock.connect(SERVER_ADDR)
|
sock.connect(SERVER_ADDR)
|
||||||
print(f"[{datetime.now().strftime('%H:%M:%S')}] ZMQ Dealer 连接到 {SERVER_ADDR}")
|
print(f"[{datetime.now().strftime('%H:%M:%S')}] ZMQ Dealer 连接到 {SERVER_ADDR}")
|
||||||
|
|
||||||
|
# ========== 上位机标签命令监听 ==========
|
||||||
|
# 使用线程安全的队列接收来自 ssmvep_main.py 的标签命令
|
||||||
|
# 标签值: 1 (train 0), 2 (train 1), 99 (predict)
|
||||||
|
pending_label = [None] # [label_value or None]
|
||||||
|
label_lock = threading.Lock()
|
||||||
|
|
||||||
|
label_cmd_sock = ctx.socket(zmq.PULL)
|
||||||
|
label_cmd_sock.bind(LABEL_CMD_ADDR)
|
||||||
|
print(f"[{datetime.now().strftime('%H:%M:%S')}] 标签命令监听绑定到 {LABEL_CMD_ADDR}")
|
||||||
|
|
||||||
|
stop_recv = threading.Event()
|
||||||
|
|
||||||
|
def label_cmd_thread():
|
||||||
|
"""监听来自上位机范式的标签命令,写入 pending_label"""
|
||||||
|
while not stop_recv.is_set():
|
||||||
|
try:
|
||||||
|
msg = label_cmd_sock.recv_string(zmq.NOBLOCK)
|
||||||
|
label_val = int(msg)
|
||||||
|
with label_lock:
|
||||||
|
pending_label[0] = label_val
|
||||||
|
ts = datetime.now().strftime('%H:%M:%S')
|
||||||
|
label_name = {1: 'train_0', 2: 'train_1', 99: 'predict'}.get(label_val, str(label_val))
|
||||||
|
print(f"[{ts}] 收到标签命令: {label_name} -> label={label_val}")
|
||||||
|
except zmq.Again:
|
||||||
|
time.sleep(0.005)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[label_cmd_thread] 错误: {e}")
|
||||||
|
time.sleep(0.01)
|
||||||
|
|
||||||
|
label_thread = threading.Thread(target=label_cmd_thread, daemon=True)
|
||||||
|
label_thread.start()
|
||||||
|
print(f"[{datetime.now().strftime('%H:%M:%S')}] 标签命令监听线程已启动")
|
||||||
|
|
||||||
# 后台消费线程:持续 recv 从 ROUTER 返回的数据,避免 server 发送队列积压
|
# 后台消费线程:持续 recv 从 ROUTER 返回的数据,避免 server 发送队列积压
|
||||||
recv_count = [0]
|
recv_count = [0]
|
||||||
stop_recv = threading.Event()
|
|
||||||
|
|
||||||
def consumer_thread():
|
def consumer_thread():
|
||||||
"""消费线程:阻塞 recv,丢弃收到的数据,仅用于清空 ROUTER 发送队列"""
|
"""消费线程:阻塞 recv,丢弃收到的数据,仅用于清空 ROUTER 发送队列"""
|
||||||
@@ -98,7 +130,7 @@ def main():
|
|||||||
print(f"[{datetime.now().strftime('%H:%M:%S')}] 开始发送模拟数据 ...")
|
print(f"[{datetime.now().strftime('%H:%M:%S')}] 开始发送模拟数据 ...")
|
||||||
print(f" 采样率: {FS}Hz | 每包 {N_SAMPLES_PER_PKT} 采样点 | 发送间隔 {PKT_INTERVAL*1000:.0f}ms")
|
print(f" 采样率: {FS}Hz | 每包 {N_SAMPLES_PER_PKT} 采样点 | 发送间隔 {PKT_INTERVAL*1000:.0f}ms")
|
||||||
print(f" EEG: {EEG_FREQ}Hz 正弦波 | 幅值 {EEG_AMP}μV")
|
print(f" EEG: {EEG_FREQ}Hz 正弦波 | 幅值 {EEG_AMP}μV")
|
||||||
print(f" 标签: 每 {LABEL_INTERVAL}s 末尾采样点触发 | label 1/2 交替")
|
print(f" 标签: 来自上位机范式命令 (train_0=1, train_1=2, predict=99)")
|
||||||
print("-" * 50)
|
print("-" * 50)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -108,30 +140,21 @@ def main():
|
|||||||
# 构建当前包
|
# 构建当前包
|
||||||
packet = build_packet(global_sample_idx)
|
packet = build_packet(global_sample_idx)
|
||||||
|
|
||||||
# 检查是否需要放置标签
|
# 检查是否有来自上位机范式的挂起标签命令
|
||||||
if should_send_label(global_sample_idx):
|
with label_lock:
|
||||||
if label_type == 1:
|
ext_label = pending_label[0]
|
||||||
label1_count += 1
|
if ext_label is not None:
|
||||||
label_value = 1
|
pending_label[0] = None
|
||||||
label_number = label1_count
|
|
||||||
else:
|
|
||||||
label2_count += 1
|
|
||||||
label_value = 2
|
|
||||||
label_number = label2_count
|
|
||||||
|
|
||||||
# 标签放在当前包最后一个采样点(索引 4)
|
|
||||||
packet[4, 64] = label_value
|
|
||||||
packet[4, 65] = label_number
|
|
||||||
|
|
||||||
|
if ext_label is not None:
|
||||||
|
# 将标签写入当前包所有5个采样点的第65通道 (index 64)
|
||||||
|
# 覆盖全部采样点确保 event_inner_idx 无论落在哪个位置都能被正确检测
|
||||||
|
packet[:, 64] = float(ext_label)
|
||||||
ts = datetime.now().strftime('%H:%M:%S')
|
ts = datetime.now().strftime('%H:%M:%S')
|
||||||
print(f"[{ts}] 标签触发: label={label_value}, 序号={label_number} "
|
print(f"[{ts}] 打标签: label={ext_label} -> ch64[all 5 samples] (global_sample_idx={global_sample_idx})")
|
||||||
f"(global_sample_idx={global_sample_idx})")
|
|
||||||
|
|
||||||
# 交替标签类型
|
# 发送: multipart 2帧 ['', data]
|
||||||
label_type = 2 if label_type == 1 else 1
|
# 使用标准格式,ROUTER 会自动附加 ZMQ 分配的客户端身份
|
||||||
|
|
||||||
# 发送: multipart 3帧 [identity, '', data]
|
|
||||||
# 使用标准格式(3帧),ROUTER 会自动附加 ZMQ 分配的客户端身份
|
|
||||||
sock.send_multipart([
|
sock.send_multipart([
|
||||||
b'',
|
b'',
|
||||||
packet.tobytes()
|
packet.tobytes()
|
||||||
@@ -156,6 +179,7 @@ def main():
|
|||||||
finally:
|
finally:
|
||||||
stop_recv.set()
|
stop_recv.set()
|
||||||
consumer.join(timeout=2)
|
consumer.join(timeout=2)
|
||||||
|
label_cmd_sock.close()
|
||||||
sock.close()
|
sock.close()
|
||||||
ctx.term()
|
ctx.term()
|
||||||
|
|
||||||
|
|||||||
230
filter_test.py
230
filter_test.py
@@ -1,11 +1,13 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
"""
|
"""
|
||||||
脑电滤波服务 8100端口测试工具【最终修复版】
|
脑电滤波服务 8100端口测试工具【统计逻辑专项优化版】
|
||||||
修复:1. Matplotlib中文字体乱码 2. ZMQ双连接收不到数据问题
|
优化点:
|
||||||
通信规范:
|
1. 5秒预热(250个发包),预热结束后才启动丢包/数据统计
|
||||||
上位机 -> 服务端:send_multipart([client_id, b"", data_buf]) 共3帧
|
2. 业务比例:0.02s发1包,200ms收1包 → 每 10 个发包对应 1 个回包
|
||||||
服务端 recv_multipart() 帧长度 = 3
|
3. 通道校验:发送(5,66) 仅对比前64通道,接收(50,64)全通道比对
|
||||||
时序:每20ms(0.02s)发送一包 (5,66),服务端200ms回传 (50,64)
|
4. 区分:全局总包数 / 有效统计区间包数、理论收包数、实际收包数、丢包数、丢包率
|
||||||
|
5. 新增64通道整体数据均值/极值比对,校验数据有效性
|
||||||
|
通信规范:send_multipart([client_id, b"", data_buf]) 三帧报文,服务端 recv_multipart 长度=3
|
||||||
"""
|
"""
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
@@ -20,33 +22,41 @@ from matplotlib.animation import FuncAnimation
|
|||||||
|
|
||||||
# ===================== 全局前置:修复Matplotlib中文字体 & 负号显示 =====================
|
# ===================== 全局前置:修复Matplotlib中文字体 & 负号显示 =====================
|
||||||
plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "WenQuanYi Micro Hei"]
|
plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "WenQuanYi Micro Hei"]
|
||||||
plt.rcParams["axes.unicode_minus"] = False # 解决负号显示异常
|
plt.rcParams["axes.unicode_minus"] = False
|
||||||
|
|
||||||
# ===================== 【1. 全局可配置参数区】 =====================
|
# ===================== 【1. 全局业务固定参数(核心统计规则)】 =====================
|
||||||
# ZMQ 服务端配置
|
# ZMQ 服务端配置
|
||||||
ZMQ_SERVER_IP = "127.0.0.1"
|
ZMQ_SERVER_IP = "127.0.0.1"
|
||||||
ZMQ_SERVER_PORT = 8100
|
ZMQ_SERVER_PORT = 8100
|
||||||
ZMQ_SOCKET_TIMEOUT = 3000 # 套接字超时(ms)
|
ZMQ_SOCKET_TIMEOUT = 3000 # 套接字超时(ms)
|
||||||
POLL_TIMEOUT = 10 # Poll轮询超时(ms),不影响发包时序
|
POLL_TIMEOUT = 10 # Poll轮询超时(ms)
|
||||||
|
|
||||||
# 数据报文配置(严格对齐业务)
|
# 时序 & 统计核心规则(严格对齐现场业务)
|
||||||
PKG_SEND_SHAPE = (5, 66) # 发送包 shape (点数, 总通道)
|
SEND_INTERVAL = 0.02 # 上位机发包间隔:20ms/包
|
||||||
PKG_RECV_SHAPE = (50, 64) # 滤波回包 shape (点数, 脑电通道)
|
RECV_INTERVAL = 0.2 # 服务端回包间隔:200ms/包
|
||||||
SEND_INTERVAL = 0.02 # 上位机发包间隔 20ms
|
PREHEAT_SECONDS = 5.0 # 滤波缓存预热时长:5秒
|
||||||
SAMPLE_RATE = 250 # 采样率 Hz
|
# 计算:预热需要的发包总数 = 预热时长 / 单包发送间隔
|
||||||
|
PREHEAT_SEND_PACKS = int(PREHEAT_SECONDS / SEND_INTERVAL) # 5 / 0.02 = 250 包
|
||||||
|
# 收发包比例:每多少个发包对应1个回包
|
||||||
|
PACK_RATIO = int(RECV_INTERVAL / SEND_INTERVAL) # 0.2 / 0.02 = 10
|
||||||
|
|
||||||
# 通道定义
|
# 数据报文形状
|
||||||
CH_EEG = 64
|
PKG_SEND_SHAPE = (5, 66) # 发送包 (点数, 总通道)
|
||||||
|
PKG_RECV_SHAPE = (50, 64) # 回包 (点数, 有效脑电通道)
|
||||||
|
SAMPLE_RATE = 250
|
||||||
|
|
||||||
|
# 通道定义(对比仅使用前64路脑电通道)
|
||||||
|
CH_EEG_VALID = 64 # 共同对比通道数:0~63
|
||||||
CH_EVENT = 64
|
CH_EVENT = 64
|
||||||
CH_RESERVED = 65
|
CH_RESERVED = 65
|
||||||
|
|
||||||
# ZMQ 三帧报文固定字段(和你服务端代码完全一致)
|
# ZMQ 三帧报文固定字段
|
||||||
CLIENT_ID = b"test_client_001"
|
CLIENT_ID = b"test_client_001"
|
||||||
EMPTY_FRAME = b""
|
EMPTY_FRAME = b""
|
||||||
|
|
||||||
# 仿真信号配置(可自由调参测试滤波)
|
# 仿真信号配置
|
||||||
TARGET_CHANNEL = 0
|
TARGET_CHANNEL = 0
|
||||||
SIGNAL_FREQ_LIST = [10.0, 22.0]
|
SIGNAL_FREQ_LIST = [13]
|
||||||
SIGNAL_AMP = 1.8
|
SIGNAL_AMP = 1.8
|
||||||
NOISE_GAUSSIAN_AMP = 0.4
|
NOISE_GAUSSIAN_AMP = 0.4
|
||||||
NOISE_POWER50_AMP = 0.3
|
NOISE_POWER50_AMP = 0.3
|
||||||
@@ -64,21 +74,32 @@ MAX_RUN_SECONDS = None
|
|||||||
ENABLE_RECONNECT = True
|
ENABLE_RECONNECT = True
|
||||||
PRINT_STAT_INTERVAL = 5.0
|
PRINT_STAT_INTERVAL = 5.0
|
||||||
|
|
||||||
# ===================== 【2. 全局变量 & 线程安全】 =====================
|
# ===================== 【2. 全局变量 + 统计结构体(重构统计逻辑)】 =====================
|
||||||
g_running = threading.Event()
|
g_running = threading.Event()
|
||||||
g_running.set()
|
g_running.set()
|
||||||
data_lock = threading.Lock()
|
data_lock = threading.Lock()
|
||||||
|
|
||||||
# 绘图数据缓冲区
|
# 绘图缓冲区
|
||||||
raw_data_buf = deque(maxlen=MAX_PLOT_POINTS)
|
raw_data_buf = deque(maxlen=MAX_PLOT_POINTS)
|
||||||
filt_data_buf = deque(maxlen=MAX_PLOT_POINTS)
|
filt_data_buf = deque(maxlen=MAX_PLOT_POINTS)
|
||||||
|
|
||||||
# 运行统计
|
# ===================== 全新统计变量(区分预热/正式统计) =====================
|
||||||
stat = {
|
stat = {
|
||||||
"send_cnt": 0,
|
# 全局总包数(包含预热包)
|
||||||
"recv_cnt": 0,
|
"total_send": 0,
|
||||||
|
"total_recv": 0,
|
||||||
|
|
||||||
|
# 有效统计区间(预热250包之后)
|
||||||
|
"valid_send": 0, # 有效发包数
|
||||||
|
"valid_recv": 0, # 有效收包数
|
||||||
|
"theo_recv": 0, # 理论应收到包数 = valid_send // PACK_RATIO
|
||||||
|
|
||||||
|
# 运行时间
|
||||||
"start_time": time.perf_counter(),
|
"start_time": time.perf_counter(),
|
||||||
"last_print_time": time.perf_counter()
|
"last_print_time": time.perf_counter(),
|
||||||
|
|
||||||
|
# 数据校验缓存:保存最新一包原始64通道数据,用于和回包比对
|
||||||
|
"latest_raw_64ch": None
|
||||||
}
|
}
|
||||||
|
|
||||||
# ===================== 【3. 日志配置】 =====================
|
# ===================== 【3. 日志配置】 =====================
|
||||||
@@ -95,74 +116,67 @@ logger = init_logger()
|
|||||||
|
|
||||||
# ===================== 【4. 仿真脑电数据生成 (5,66)】 =====================
|
# ===================== 【4. 仿真脑电数据生成 (5,66)】 =====================
|
||||||
def generate_eeg_packet(pkt_idx: int) -> np.ndarray:
|
def generate_eeg_packet(pkt_idx: int) -> np.ndarray:
|
||||||
"""生成单包 (5,66) 仿真数据:脑电+噪声+工频+事件通道+保留通道"""
|
"""生成单包 (5,66) 仿真数据"""
|
||||||
n_point, n_chan = PKG_SEND_SHAPE
|
n_point, n_chan = PKG_SEND_SHAPE
|
||||||
base_t = pkt_idx * n_point / SAMPLE_RATE
|
base_t = pkt_idx * n_point / SAMPLE_RATE
|
||||||
t_arr = base_t + np.arange(n_point) / SAMPLE_RATE
|
t_arr = base_t + np.arange(n_point) / SAMPLE_RATE
|
||||||
|
|
||||||
data = np.zeros((n_point, n_chan), dtype=np.float64)
|
data = np.zeros((n_point, n_chan), dtype=np.float64)
|
||||||
|
|
||||||
# 64路脑电:多频信号 + 50Hz工频 + 高斯白噪声
|
# 64路脑电信号
|
||||||
for ch in range(CH_EEG):
|
for ch in range(CH_EEG_VALID):
|
||||||
sig = 0.0
|
sig = 0.0
|
||||||
for freq in SIGNAL_FREQ_LIST:
|
for freq in SIGNAL_FREQ_LIST:
|
||||||
sig += SIGNAL_AMP * np.sin(2 * np.pi * freq * t_arr)
|
sig += SIGNAL_AMP * np.sin(2 * np.pi * freq * t_arr)
|
||||||
sig += NOISE_POWER50_AMP * np.sin(2 * np.pi * 50 * t_arr)
|
# sig += NOISE_POWER50_AMP * np.sin(2 * np.pi * 50 * t_arr)
|
||||||
sig += NOISE_GAUSSIAN_AMP * np.random.randn(n_point)
|
# sig += NOISE_GAUSSIAN_AMP * np.random.randn(n_point)
|
||||||
data[:, ch] = sig
|
data[:, ch] = sig
|
||||||
|
|
||||||
# 事件通道、保留通道赋值
|
# 事件通道、保留通道
|
||||||
data[:, CH_EVENT] = EVENT_LABEL_VAL
|
data[:, CH_EVENT] = EVENT_LABEL_VAL
|
||||||
data[:, CH_RESERVED] = RESERVED_VAL
|
data[:, CH_RESERVED] = RESERVED_VAL
|
||||||
return data
|
return data
|
||||||
|
|
||||||
# ===================== 【5. 核心修复:单DEALER连接 + Poller 同时收发】 =====================
|
# ===================== 【5. ZMQ 核心IO线程(单连接+Poller,保留原有通信逻辑)】 =====================
|
||||||
def zmq_io_thread():
|
def zmq_io_thread():
|
||||||
"""
|
|
||||||
唯一ZMQ工作线程:单个DEALER连接,同时发包+收包(对齐真实上位机)
|
|
||||||
使用 Poller 多路复用,避免阻塞、超时报错
|
|
||||||
"""
|
|
||||||
context = zmq.Context()
|
context = zmq.Context()
|
||||||
pkt_index = 0
|
pkt_index = 0
|
||||||
send_interval = SEND_INTERVAL
|
send_interval = SEND_INTERVAL
|
||||||
|
|
||||||
|
logger.info(f"滤波预热配置:{PREHEAT_SECONDS}秒 / {PREHEAT_SEND_PACKS} 个发包后开始统计")
|
||||||
|
logger.info(f"收发比例:每 {PACK_RATIO} 个发包 → 1 个滤波回包")
|
||||||
|
|
||||||
while g_running.is_set():
|
while g_running.is_set():
|
||||||
try:
|
try:
|
||||||
# 新建 DEALER 套接字(全局唯一连接)
|
|
||||||
sock = context.socket(zmq.DEALER)
|
sock = context.socket(zmq.DEALER)
|
||||||
sock.setsockopt(zmq.RCVTIMEO, ZMQ_SOCKET_TIMEOUT)
|
sock.setsockopt(zmq.RCVTIMEO, ZMQ_SOCKET_TIMEOUT)
|
||||||
sock.setsockopt(zmq.SNDTIMEO, ZMQ_SOCKET_TIMEOUT)
|
sock.setsockopt(zmq.SNDTIMEO, ZMQ_SOCKET_TIMEOUT)
|
||||||
sock.connect(f"tcp://{ZMQ_SERVER_IP}:{ZMQ_SERVER_PORT}")
|
sock.connect(f"tcp://{ZMQ_SERVER_IP}:{ZMQ_SERVER_PORT}")
|
||||||
logger.info(f"ZMQ 连接成功 -> {ZMQ_SERVER_IP}:{ZMQ_SERVER_PORT}")
|
logger.info(f"ZMQ 连接成功 -> {ZMQ_SERVER_IP}:{ZMQ_SERVER_PORT}")
|
||||||
|
|
||||||
# 注册Poller:监听当前套接字的可读事件
|
|
||||||
poller = zmq.Poller()
|
poller = zmq.Poller()
|
||||||
poller.register(sock, zmq.POLLIN)
|
poller.register(sock, zmq.POLLIN)
|
||||||
|
|
||||||
# 精准发包计时(消除sleep漂移)
|
|
||||||
next_send_ts = time.perf_counter()
|
next_send_ts = time.perf_counter()
|
||||||
|
|
||||||
while g_running.is_set():
|
while g_running.is_set():
|
||||||
# 1. 运行时长限制判断
|
# 全局运行时长限制
|
||||||
if MAX_RUN_SECONDS is not None:
|
if MAX_RUN_SECONDS is not None:
|
||||||
run_sec = time.perf_counter() - stat["start_time"]
|
run_sec = time.perf_counter() - stat["start_time"]
|
||||||
if run_sec > MAX_RUN_SECONDS:
|
if run_sec > MAX_RUN_SECONDS:
|
||||||
logger.info(f"已到达设定运行时长 {MAX_RUN_SECONDS}s,停止任务")
|
logger.info(f"已到达设定运行时长 {MAX_RUN_SECONDS}s,停止任务")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 2. Poll 轮询:有数据就接收,无数据继续执行发包逻辑
|
# ========== 1. 轮询接收服务端回包 ==========
|
||||||
socks_ready = dict(poller.poll(POLL_TIMEOUT))
|
socks_ready = dict(poller.poll(POLL_TIMEOUT))
|
||||||
if sock in socks_ready:
|
if sock in socks_ready:
|
||||||
# ========== 接收服务端回包 (multipart) ==========
|
|
||||||
frames = sock.recv_multipart()
|
frames = sock.recv_multipart()
|
||||||
if not frames:
|
if not frames:
|
||||||
continue
|
continue
|
||||||
# 取最后一帧为有效滤波数据
|
|
||||||
recv_bytes = frames[-1]
|
recv_bytes = frames[-1]
|
||||||
if not recv_bytes:
|
if not recv_bytes:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 解析为 (50,64) float64
|
# 解析回包 (50,64)
|
||||||
filt_data = np.frombuffer(recv_bytes, dtype=np.float64)
|
filt_data = np.frombuffer(recv_bytes, dtype=np.float64)
|
||||||
expect_size = PKG_RECV_SHAPE[0] * PKG_RECV_SHAPE[1]
|
expect_size = PKG_RECV_SHAPE[0] * PKG_RECV_SHAPE[1]
|
||||||
if filt_data.size != expect_size:
|
if filt_data.size != expect_size:
|
||||||
@@ -170,42 +184,89 @@ def zmq_io_thread():
|
|||||||
continue
|
continue
|
||||||
filt_data = filt_data.reshape(PKG_RECV_SHAPE)
|
filt_data = filt_data.reshape(PKG_RECV_SHAPE)
|
||||||
|
|
||||||
# 统计 + 写入绘图缓冲区
|
# 全局收包计数
|
||||||
stat["recv_cnt"] += 1
|
stat["total_recv"] += 1
|
||||||
|
|
||||||
|
# 仅预热完成后,计入有效统计收包
|
||||||
|
if stat["total_send"] > PREHEAT_SEND_PACKS:
|
||||||
|
stat["valid_recv"] += 1
|
||||||
|
|
||||||
|
# 写入绘图缓冲区
|
||||||
with data_lock:
|
with data_lock:
|
||||||
filt_data_buf.extend(filt_data[:, TARGET_CHANNEL])
|
filt_data_buf.extend(filt_data[:, TARGET_CHANNEL])
|
||||||
|
|
||||||
# 定时打印运行状态
|
# ---------- 新增:64通道数据比对(发包前64通道 <-> 回包64通道) ----------
|
||||||
now = time.perf_counter()
|
raw_64ch = stat["latest_raw_64ch"]
|
||||||
if now - stat["last_print_time"] > PRINT_STAT_INTERVAL:
|
if raw_64ch is not None:
|
||||||
run_sec = now - stat["start_time"]
|
raw_mean = np.mean(raw_64ch)
|
||||||
loss_rate = (stat["send_cnt"] - stat["recv_cnt"]) / stat["send_cnt"] * 100 if stat["send_cnt"] > 0 else 0.0
|
filt_mean = np.mean(filt_data)
|
||||||
logger.info(
|
raw_amp = np.max(np.abs(raw_64ch))
|
||||||
f"运行:{run_sec:.1f}s | 发包:{stat['send_cnt']} | 收包:{stat['recv_cnt']} | 丢包率:{loss_rate:.2f}%"
|
filt_amp = np.max(np.abs(filt_data))
|
||||||
|
logger.debug(
|
||||||
|
f"【通道数据比对】原始64通道均值:{raw_mean:.4f} 幅值:{raw_amp:.4f} | "
|
||||||
|
f"滤波后均值:{filt_mean:.4f} 幅值:{filt_amp:.4f}"
|
||||||
)
|
)
|
||||||
stat["last_print_time"] = now
|
|
||||||
|
|
||||||
# 3. 精准定时发包(严格20ms间隔)
|
# ========== 2. 精准定时发送数据包 ==========
|
||||||
current_ts = time.perf_counter()
|
current_ts = time.perf_counter()
|
||||||
if current_ts >= next_send_ts:
|
if current_ts >= next_send_ts:
|
||||||
# 生成 (5,66) 仿真数据包
|
# 生成(5,66)仿真包
|
||||||
pkt_data = generate_eeg_packet(pkt_index)
|
pkt_data = generate_eeg_packet(pkt_index)
|
||||||
pkt_index += 1
|
pkt_index += 1
|
||||||
send_buf = pkt_data.tobytes()
|
send_buf = pkt_data.tobytes()
|
||||||
|
|
||||||
# ========== 三帧Multipart发送(和你服务端代码完全一致) ==========
|
# 标准三帧Multipart发送
|
||||||
sock.send_multipart([CLIENT_ID, EMPTY_FRAME, send_buf])
|
sock.send_multipart([CLIENT_ID, EMPTY_FRAME, send_buf])
|
||||||
|
|
||||||
# 统计 + 写入原始数据缓冲区
|
# ---------- 发包计数逻辑(核心优化:预热区分) ----------
|
||||||
stat["send_cnt"] += 1
|
stat["total_send"] += 1
|
||||||
|
# 预热完成后,计入有效发包
|
||||||
|
if stat["total_send"] > PREHEAT_SEND_PACKS:
|
||||||
|
stat["valid_send"] += 1
|
||||||
|
# 计算理论应收包数
|
||||||
|
stat["theo_recv"] = stat["valid_send"] // PACK_RATIO
|
||||||
|
|
||||||
|
# 缓存当前包前64通道,用于后续数据比对
|
||||||
|
stat["latest_raw_64ch"] = pkt_data[:, :CH_EEG_VALID]
|
||||||
|
|
||||||
|
# 绘图缓冲区(单通道波形)
|
||||||
with data_lock:
|
with data_lock:
|
||||||
raw_data_buf.extend(pkt_data[:, TARGET_CHANNEL])
|
raw_data_buf.extend(pkt_data[:, TARGET_CHANNEL])
|
||||||
|
|
||||||
# 更新下一次发包时间戳
|
# 更新下一次发包时间
|
||||||
next_send_ts += send_interval
|
next_send_ts += send_interval
|
||||||
|
|
||||||
|
# ========== 3. 定时打印统计信息(区分预热/正式统计) ==========
|
||||||
|
now = time.perf_counter()
|
||||||
|
if now - stat["last_print_time"] > PRINT_STAT_INTERVAL:
|
||||||
|
run_sec = now - stat["start_time"]
|
||||||
|
total_send = stat["total_send"]
|
||||||
|
total_recv = stat["total_recv"]
|
||||||
|
|
||||||
|
# 分支1:仍在预热阶段
|
||||||
|
if total_send <= PREHEAT_SEND_PACKS:
|
||||||
|
remain = PREHEAT_SEND_PACKS - total_send
|
||||||
|
logger.info(
|
||||||
|
f"[预热中] 运行:{run_sec:.1f}s | 已发包:{total_send}/{PREHEAT_SEND_PACKS} | "
|
||||||
|
f"剩余预热包:{remain} | 暂不统计丢包"
|
||||||
|
)
|
||||||
|
# 分支2:预热完成,进入正式统计
|
||||||
|
else:
|
||||||
|
v_send = stat["valid_send"]
|
||||||
|
v_recv = stat["valid_recv"]
|
||||||
|
t_recv = stat["theo_recv"]
|
||||||
|
loss_cnt = t_recv - v_recv
|
||||||
|
loss_rate = (loss_cnt / t_recv * 100) if t_recv > 0 else 0.0
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[正式统计] 运行:{run_sec:.1f}s | "
|
||||||
|
f"全局总包: 发{total_send}/收{total_recv} | "
|
||||||
|
f"有效区间: 发{v_send}/应收{t_recv}/实收{v_recv} | "
|
||||||
|
f"丢包数:{loss_cnt} | 丢包率:{loss_rate:.2f}%"
|
||||||
|
)
|
||||||
|
stat["last_print_time"] = now
|
||||||
|
|
||||||
except zmq.ZMQError as e:
|
except zmq.ZMQError as e:
|
||||||
# 区分正常超时 和 网络异常
|
|
||||||
if e.errno == zmq.EAGAIN:
|
if e.errno == zmq.EAGAIN:
|
||||||
continue
|
continue
|
||||||
logger.warning(f"ZMQ 连接异常: {e}")
|
logger.warning(f"ZMQ 连接异常: {e}")
|
||||||
@@ -222,7 +283,7 @@ def zmq_io_thread():
|
|||||||
context.term()
|
context.term()
|
||||||
logger.info("ZMQ IO 线程已退出")
|
logger.info("ZMQ IO 线程已退出")
|
||||||
|
|
||||||
# ===================== 【6. 可视化绘图(无逻辑改动,已前置修复字体)】 =====================
|
# ===================== 【6. 可视化绘图(无改动)】 =====================
|
||||||
def init_plot():
|
def init_plot():
|
||||||
fig = plt.figure(figsize=(14, 9))
|
fig = plt.figure(figsize=(14, 9))
|
||||||
fig.suptitle(f"脑电滤波测试 | 观测通道: {TARGET_CHANNEL}", fontsize=14)
|
fig.suptitle(f"脑电滤波测试 | 观测通道: {TARGET_CHANNEL}", fontsize=14)
|
||||||
@@ -264,7 +325,6 @@ def update_plot(frame, lines, axes):
|
|||||||
raw_data = list(raw_data_buf)
|
raw_data = list(raw_data_buf)
|
||||||
filt_data = list(filt_data_buf)
|
filt_data = list(filt_data_buf)
|
||||||
|
|
||||||
# 时域波形
|
|
||||||
if raw_data:
|
if raw_data:
|
||||||
x_raw = np.arange(len(raw_data))
|
x_raw = np.arange(len(raw_data))
|
||||||
line_raw.set_data(x_raw, raw_data)
|
line_raw.set_data(x_raw, raw_data)
|
||||||
@@ -277,7 +337,6 @@ def update_plot(frame, lines, axes):
|
|||||||
ax2.relim()
|
ax2.relim()
|
||||||
ax2.autoscale_view()
|
ax2.autoscale_view()
|
||||||
|
|
||||||
# 频谱计算(汉宁窗减少频谱泄露)
|
|
||||||
def calc_fft(sig, n_fft):
|
def calc_fft(sig, n_fft):
|
||||||
if len(sig) < n_fft:
|
if len(sig) < n_fft:
|
||||||
return [], []
|
return [], []
|
||||||
@@ -300,7 +359,7 @@ def update_plot(frame, lines, axes):
|
|||||||
|
|
||||||
return lines
|
return lines
|
||||||
|
|
||||||
# ===================== 【7. 资源释放 & 主入口】 =====================
|
# ===================== 【7. 资源释放 & 最终汇总统计】 =====================
|
||||||
def clean_resource():
|
def clean_resource():
|
||||||
g_running.clear()
|
g_running.clear()
|
||||||
logger.info("开始停止所有线程...")
|
logger.info("开始停止所有线程...")
|
||||||
@@ -309,18 +368,19 @@ def clean_resource():
|
|||||||
logger.info("资源释放完成")
|
logger.info("资源释放完成")
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
logger.info("=" * 60)
|
logger.info("=" * 70)
|
||||||
logger.info("脑电滤波测试客户端 【修复版】启动")
|
logger.info("脑电滤波测试客户端【统计逻辑优化版】启动")
|
||||||
logger.info(f"服务端地址: {ZMQ_SERVER_IP}:{ZMQ_SERVER_PORT}")
|
logger.info(f"服务端地址: {ZMQ_SERVER_IP}:{ZMQ_SERVER_PORT}")
|
||||||
logger.info(f"发包格式: {PKG_SEND_SHAPE} | 间隔: {SEND_INTERVAL*1000:.0f}ms")
|
logger.info(f"发包: {PKG_SEND_SHAPE}({SEND_INTERVAL*1000:.0f}ms) | 回包: {PKG_RECV_SHAPE}({RECV_INTERVAL*1000:.0f}ms)")
|
||||||
logger.info(f"回包格式: {PKG_RECV_SHAPE} | ZMQ三帧报文 [客户端ID, 空帧, 数据帧]")
|
logger.info(f"预热规则: {PREHEAT_SECONDS}秒 / {PREHEAT_SEND_PACKS} 包后开启统计")
|
||||||
logger.info("=" * 60)
|
logger.info(f"收发比例: 每 {PACK_RATIO} 个发包对应 1 个回包")
|
||||||
|
logger.info("=" * 70)
|
||||||
|
|
||||||
# 启动唯一ZMQ收发线程
|
# 启动ZMQ收发线程
|
||||||
io_thread = threading.Thread(target=zmq_io_thread, daemon=True, name="ZMQ_IO_Thread")
|
io_thread = threading.Thread(target=zmq_io_thread, daemon=True, name="ZMQ_IO_Thread")
|
||||||
io_thread.start()
|
io_thread.start()
|
||||||
|
|
||||||
# 启动可视化绘图
|
# 启动可视化
|
||||||
fig, lines, axes = init_plot()
|
fig, lines, axes = init_plot()
|
||||||
ani = FuncAnimation(
|
ani = FuncAnimation(
|
||||||
fig, update_plot,
|
fig, update_plot,
|
||||||
@@ -330,20 +390,30 @@ def main():
|
|||||||
cache_frame_data=False
|
cache_frame_data=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# 主线程阻塞,监听关闭
|
|
||||||
try:
|
try:
|
||||||
plt.show()
|
plt.show()
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logger.info("收到 Ctrl+C 中断信号,准备退出")
|
logger.info("收到 Ctrl+C 中断信号,准备退出")
|
||||||
finally:
|
finally:
|
||||||
# 输出最终统计
|
# 输出最终完整汇总报表
|
||||||
run_total = time.perf_counter() - stat["start_time"]
|
run_total = time.perf_counter() - stat["start_time"]
|
||||||
loss_rate = (stat["send_cnt"] - stat["recv_cnt"]) / stat["send_cnt"] * 100 if stat["send_cnt"] > 0 else 0.0
|
total_send = stat["total_send"]
|
||||||
logger.info(f"\n===== 运行汇总 =====")
|
total_recv = stat["total_recv"]
|
||||||
|
v_send = stat["valid_send"]
|
||||||
|
v_recv = stat["valid_recv"]
|
||||||
|
t_recv = stat["theo_recv"]
|
||||||
|
|
||||||
|
loss_cnt = t_recv - v_recv
|
||||||
|
loss_rate = (loss_cnt / t_recv * 100) if t_recv > 0 else 0.0
|
||||||
|
|
||||||
|
logger.info(f"\n{'='*50} 最终运行汇总 {'='*50}")
|
||||||
logger.info(f"总运行时长: {run_total:.1f} s")
|
logger.info(f"总运行时长: {run_total:.1f} s")
|
||||||
logger.info(f"总发包数: {stat['send_cnt']}")
|
logger.info(f"【全局总包数】发送: {total_send} | 接收: {total_recv}")
|
||||||
logger.info(f"总收包数: {stat['recv_cnt']}")
|
logger.info(f"【有效统计区间(跳过预热{PREHEAT_SEND_PACKS}包)】")
|
||||||
logger.info(f"整体丢包率: {loss_rate:.2f} %")
|
logger.info(f" 有效发包: {v_send} | 理论应收包: {t_recv} | 实际收包: {v_recv}")
|
||||||
|
logger.info(f" 总丢包数: {loss_cnt} | 整体丢包率: {loss_rate:.2f} %")
|
||||||
|
logger.info(f"{'='*106}")
|
||||||
|
|
||||||
clean_resource()
|
clean_resource()
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
|
|||||||
150
logs/log.py
150
logs/log.py
@@ -1,114 +1,172 @@
|
|||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
import logging
|
import logging
|
||||||
from logging.handlers import RotatingFileHandler
|
from logging.handlers import RotatingFileHandler
|
||||||
import inspect
|
import inspect
|
||||||
|
try:
|
||||||
|
import win32gui
|
||||||
|
import win32con
|
||||||
|
WIN32_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
WIN32_AVAILABLE = False
|
||||||
|
|
||||||
from PubLibrary.InifileHelper import IniRead
|
from PubLibrary.InifileHelper import IniRead
|
||||||
|
|
||||||
# 全局配置
|
# ===================== 新增:获取 EXE 同级目录 =====================
|
||||||
console_output = IniRead('system', 'console_output', '1')
|
def get_app_root():
|
||||||
log_level = IniRead('system', 'algo_log_level', 'INFO')
|
"""获取 runDecoder.exe 所在的真实根目录(兼容 onefile / standalone)"""
|
||||||
|
if getattr(sys, 'frozen', False):
|
||||||
|
# Nuitka / PyInstaller 打包后走这里
|
||||||
|
app_path = sys.executable
|
||||||
|
else:
|
||||||
|
# 本地源码运行时,取当前脚本目录
|
||||||
|
app_path = os.path.abspath(__file__)
|
||||||
|
return os.path.dirname(app_path)
|
||||||
|
|
||||||
|
# 程序根目录(exe 同级)
|
||||||
|
APP_ROOT = Path(get_app_root())
|
||||||
|
# 日志文件夹名:exe 同级下 logs 目录
|
||||||
|
DEFAULT_LOG_DIR = APP_ROOT / "logs"
|
||||||
|
|
||||||
|
# ===================== 读取 [algo_log] 配置 =====================
|
||||||
|
# 文件日志
|
||||||
|
FILE_LOG_ENABLE = IniRead("algo_log", "file_log_enable", "true").lower() == "true"
|
||||||
|
FILE_LOG_LEVEL = IniRead("algo_log", "file_log_level", "DEBUG").upper()
|
||||||
|
# 优先级:配置文件 > 默认exe同级logs
|
||||||
|
CFG_LOG_PATH = IniRead("algo_log", "log_path", "").strip()
|
||||||
|
if CFG_LOG_PATH == "exe":
|
||||||
|
LOG_DIR = DEFAULT_LOG_DIR
|
||||||
|
else:
|
||||||
|
LOG_DIR = Path(CFG_LOG_PATH)
|
||||||
|
|
||||||
|
LOG_RETENTION_DAYS = int(IniRead("algo_log", "retention_days", 3))
|
||||||
|
|
||||||
|
# 控制台日志 + 黑框控制
|
||||||
|
CONSOLE_ENABLE = IniRead("algo_log", "console_enable", "true").lower() == "true"
|
||||||
|
CONSOLE_SHOW_WINDOW = IniRead("algo_log", "console_show_window", "true").lower() == "true"
|
||||||
|
CONSOLE_LOG_LEVEL = IniRead("algo_log", "console_log_level", "INFO").upper()
|
||||||
|
|
||||||
|
# ===================== 全局常量与缓存 =====================
|
||||||
log_once_cache = set()
|
log_once_cache = set()
|
||||||
logger_cache = {}
|
logger_cache = {}
|
||||||
LOG_RETENTION_DAYS = 3
|
|
||||||
LOG_DIR = './logs/'
|
|
||||||
LOG_FILE_PREFIX = 'algo_log_'
|
LOG_FILE_PREFIX = 'algo_log_'
|
||||||
|
# 确保日志目录存在
|
||||||
|
LOG_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
LOG_DIR_STR = str(LOG_DIR) + "\\"
|
||||||
|
|
||||||
# 日志格式:时间 - 日志器名 - 级别 - 文件名:行号 - 函数名 - 日志内容
|
# 日志格式
|
||||||
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||||
DATE_FORMAT = '%Y-%m-%d %H:%M:%S'
|
DATE_FORMAT = '%Y-%m-%d %H:%M:%S'
|
||||||
|
|
||||||
|
# 日志级别映射
|
||||||
|
LEVEL_MAP = {
|
||||||
|
"DEBUG": logging.DEBUG,
|
||||||
|
"INFO": logging.INFO,
|
||||||
|
"WARNING": logging.WARNING,
|
||||||
|
"ERROR": logging.ERROR,
|
||||||
|
"FATAL": logging.FATAL
|
||||||
|
}
|
||||||
|
FILE_LOG_LEVEL_INT = LEVEL_MAP.get(FILE_LOG_LEVEL, logging.INFO)
|
||||||
|
CONSOLE_LOG_LEVEL_INT = LEVEL_MAP.get(CONSOLE_LOG_LEVEL, logging.INFO)
|
||||||
|
|
||||||
def clean_old_logs():
|
# ===================== Windows 控制台黑框显示/隐藏 =====================
|
||||||
"""清理超过指定天数的旧日志文件"""
|
def control_console_window():
|
||||||
|
if not sys.platform.startswith("win") or not WIN32_AVAILABLE:
|
||||||
|
return
|
||||||
try:
|
try:
|
||||||
if not os.path.exists(LOG_DIR):
|
hwnd = win32gui.GetForegroundWindow()
|
||||||
|
if CONSOLE_SHOW_WINDOW:
|
||||||
|
win32gui.ShowWindow(hwnd, win32con.SW_SHOW)
|
||||||
|
else:
|
||||||
|
win32gui.ShowWindow(hwnd, win32con.SW_HIDE)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
control_console_window()
|
||||||
|
|
||||||
|
# ===================== 清理过期日志 =====================
|
||||||
|
def clean_old_logs():
|
||||||
|
try:
|
||||||
|
if not LOG_DIR.exists():
|
||||||
return
|
return
|
||||||
expire_date = datetime.now() - timedelta(days=LOG_RETENTION_DAYS)
|
expire_date = datetime.now() - timedelta(days=LOG_RETENTION_DAYS)
|
||||||
for filename in os.listdir(LOG_DIR):
|
for filename in os.listdir(LOG_DIR):
|
||||||
if not filename.startswith(LOG_FILE_PREFIX) or not filename.endswith('.log'):
|
if not (filename.startswith(LOG_FILE_PREFIX) and filename.endswith('.log')):
|
||||||
continue
|
continue
|
||||||
date_str = filename[len(LOG_FILE_PREFIX):-4]
|
date_str = filename[len(LOG_FILE_PREFIX):-4]
|
||||||
try:
|
try:
|
||||||
file_date = datetime.strptime(date_str, '%Y-%m-%d')
|
file_date = datetime.strptime(date_str, '%Y-%m-%d')
|
||||||
if file_date < expire_date:
|
if file_date < expire_date:
|
||||||
file_path = os.path.join(LOG_DIR, filename)
|
file_path = LOG_DIR / filename
|
||||||
os.remove(file_path)
|
os.remove(file_path)
|
||||||
print(f"清理过期日志: {file_path}")
|
|
||||||
except ValueError:
|
except ValueError:
|
||||||
continue
|
continue
|
||||||
except Exception as e:
|
except Exception:
|
||||||
print(f"清理旧日志异常: {str(e)}")
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# ===================== 初始化日志器 =====================
|
||||||
def init_module_logger(logger_name):
|
def init_module_logger(logger_name):
|
||||||
"""初始化日志器 + 清理旧日志"""
|
|
||||||
os.makedirs(LOG_DIR, exist_ok=True)
|
|
||||||
clean_old_logs()
|
|
||||||
|
|
||||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
|
||||||
log_file = os.path.join(LOG_DIR, f"{LOG_FILE_PREFIX}{current_date}.log")
|
|
||||||
|
|
||||||
if logger_name in logger_cache:
|
if logger_name in logger_cache:
|
||||||
return logger_cache[logger_name]
|
return logger_cache[logger_name]
|
||||||
|
|
||||||
|
clean_old_logs()
|
||||||
|
|
||||||
logger = logging.getLogger(logger_name)
|
logger = logging.getLogger(logger_name)
|
||||||
logger.setLevel(log_level)
|
logger.setLevel(logging.DEBUG)
|
||||||
if logger.handlers:
|
if logger.handlers:
|
||||||
logger_cache[logger_name] = logger
|
logger_cache[logger_name] = logger
|
||||||
return logger
|
return logger
|
||||||
|
|
||||||
# 文件输出处理器
|
formatter = logging.Formatter(LOG_FORMAT, datefmt=DATE_FORMAT)
|
||||||
|
|
||||||
|
# 文件日志
|
||||||
|
if FILE_LOG_ENABLE:
|
||||||
|
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||||
|
log_file = LOG_DIR / f"{LOG_FILE_PREFIX}{current_date}.log"
|
||||||
file_handler = RotatingFileHandler(
|
file_handler = RotatingFileHandler(
|
||||||
log_file,
|
log_file,
|
||||||
maxBytes=10 * 1024 * 1024,
|
maxBytes=10 * 1024 * 1024,
|
||||||
backupCount=10,
|
backupCount=10,
|
||||||
encoding='utf-8'
|
encoding='utf-8'
|
||||||
)
|
)
|
||||||
formatter = logging.Formatter(LOG_FORMAT, datefmt=DATE_FORMAT)
|
|
||||||
file_handler.setFormatter(formatter)
|
file_handler.setFormatter(formatter)
|
||||||
|
file_handler.setLevel(FILE_LOG_LEVEL_INT)
|
||||||
logger.addHandler(file_handler)
|
logger.addHandler(file_handler)
|
||||||
|
|
||||||
# 控制台输出
|
# 控制台日志
|
||||||
if console_output:
|
if CONSOLE_ENABLE:
|
||||||
console_handler = logging.StreamHandler()
|
console_handler = logging.StreamHandler(sys.stdout)
|
||||||
console_handler.setFormatter(formatter)
|
console_handler.setFormatter(formatter)
|
||||||
|
console_handler.setLevel(CONSOLE_LOG_LEVEL_INT)
|
||||||
logger.addHandler(console_handler)
|
logger.addHandler(console_handler)
|
||||||
|
|
||||||
logger_cache[logger_name] = logger
|
logger_cache[logger_name] = logger
|
||||||
return logger
|
return logger
|
||||||
|
|
||||||
|
# ===================== 对外日志入口函数 =====================
|
||||||
def algo_log(content, level="INFO", record_once=False):
|
def algo_log(content, level="INFO", record_once=False):
|
||||||
"""
|
frame = inspect.currentframe()
|
||||||
日志入口函数
|
if frame:
|
||||||
自动记录:调用文件名、代码行号、所在函数
|
frame = frame.f_back.f_back
|
||||||
"""
|
file_name = os.path.basename(frame.f_code.co_filename) if frame else "unknown"
|
||||||
# 回溯栈帧,获取真正调用 algo_log 的代码位置
|
|
||||||
# f_back(1) -> algo_log 自身,f_back(2) -> 业务调用处
|
|
||||||
frame = inspect.currentframe().f_back.f_back
|
|
||||||
if not frame:
|
|
||||||
file_name = "unknown"
|
|
||||||
else:
|
|
||||||
file_name = os.path.basename(frame.f_code.co_filename)
|
|
||||||
|
|
||||||
logger = init_module_logger(file_name)
|
logger = init_module_logger(file_name)
|
||||||
|
|
||||||
# 单次日志去重
|
|
||||||
if record_once:
|
if record_once:
|
||||||
log_key = f"{level.upper()}_{content}"
|
log_key = f"{level.upper()}_{content}"
|
||||||
if log_key in log_once_cache:
|
if log_key in log_once_cache:
|
||||||
return
|
return
|
||||||
log_once_cache.add(log_key)
|
log_once_cache.add(log_key)
|
||||||
|
|
||||||
# 日志级别分发
|
|
||||||
level_upper = level.upper()
|
level_upper = level.upper()
|
||||||
log_map = {
|
log_func_map = {
|
||||||
"DEBUG": logger.debug,
|
"DEBUG": logger.debug,
|
||||||
|
"INFO": logger.info,
|
||||||
"WARNING": logger.warning,
|
"WARNING": logger.warning,
|
||||||
"ERROR": logger.error,
|
"ERROR": logger.error,
|
||||||
"FATAL": logger.fatal,
|
"FATAL": logger.fatal
|
||||||
"INFO": logger.info
|
|
||||||
}
|
}
|
||||||
log_func = log_map.get(level_upper, logger.info)
|
log_func = log_func_map.get(level_upper, logger.info)
|
||||||
log_func(content)
|
log_func(content)
|
||||||
54
nuitka_3in1_package.sh
Normal file
54
nuitka_3in1_package.sh
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Git Bash 中文 UTF-8 兼容配置(通用版,无报错)
|
||||||
|
export LC_ALL=en_US.UTF-8
|
||||||
|
export LANG=en_US.UTF-8
|
||||||
|
|
||||||
|
echo "========================"
|
||||||
|
echo "Nuitka 打包脚本 - 优化稳定版"
|
||||||
|
echo "适配:PyTorch2.0.0 + CUDA11.7 + 脑电解码项目"
|
||||||
|
echo "========================"
|
||||||
|
|
||||||
|
# ===================== 自定义配置区 =====================
|
||||||
|
PY_FILE="runDecoder.py" # 主程序文件
|
||||||
|
OUT_DIR="dist_nuitka" # 输出文件夹
|
||||||
|
MODEL_DIR="online_Models" # 模型文件夹
|
||||||
|
# ========================================================
|
||||||
|
|
||||||
|
# 检查主文件是否存在
|
||||||
|
if [ ! -f "${PY_FILE}" ]; then
|
||||||
|
echo "错误:未找到主文件 ${PY_FILE},请检查路径!"
|
||||||
|
read -n 1 -s -r -p "按任意键退出"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "开始打包:${PY_FILE}"
|
||||||
|
echo "输出目录:${OUT_DIR}"
|
||||||
|
|
||||||
|
# Nuitka 核心打包命令(无错误、无冗余、全依赖)
|
||||||
|
python -m nuitka \
|
||||||
|
--standalone \
|
||||||
|
--msvc=latest \
|
||||||
|
--module-parameter=torch-disable-jit=yes \
|
||||||
|
--enable-plugin=no-qt \
|
||||||
|
--include-package=numpy \
|
||||||
|
--include-module=numpy.core._multiarray_umath \
|
||||||
|
--include-package=scipy \
|
||||||
|
--no-deployment-flag=self-execution \
|
||||||
|
--include-data-dir="${MODEL_DIR}=${MODEL_DIR}" \
|
||||||
|
--output-dir="${OUT_DIR}" \
|
||||||
|
--remove-output \
|
||||||
|
"${PY_FILE}"
|
||||||
|
|
||||||
|
# 打包结果判断
|
||||||
|
if [ $? -eq 0 ]; then
|
||||||
|
echo -e "\n========================"
|
||||||
|
echo "✅ 打包成功!"
|
||||||
|
echo "📦 产物路径:${OUT_DIR}/${PY_FILE%.py}.exe"
|
||||||
|
echo "========================"
|
||||||
|
else
|
||||||
|
echo -e "\n❌ 打包失败!"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Git Bash 兼容的暂停
|
||||||
|
read -n 1 -s -r -p "按任意键退出..."
|
||||||
|
echo
|
||||||
@@ -1,252 +0,0 @@
|
|||||||
0 0.5
|
|
||||||
1 0.5
|
|
||||||
2 0.375
|
|
||||||
3 0.5
|
|
||||||
4 0.4375
|
|
||||||
5 0.375
|
|
||||||
6 0.5
|
|
||||||
7 0.5
|
|
||||||
8 0.375
|
|
||||||
9 0.375
|
|
||||||
10 0.375
|
|
||||||
11 0.375
|
|
||||||
12 0.5
|
|
||||||
13 0.5625
|
|
||||||
14 0.5625
|
|
||||||
15 0.5
|
|
||||||
16 0.5
|
|
||||||
17 0.5
|
|
||||||
18 0.5
|
|
||||||
19 0.5625
|
|
||||||
20 0.4375
|
|
||||||
21 0.5
|
|
||||||
22 0.5
|
|
||||||
23 0.375
|
|
||||||
24 0.375
|
|
||||||
25 0.375
|
|
||||||
26 0.375
|
|
||||||
27 0.375
|
|
||||||
28 0.3125
|
|
||||||
29 0.375
|
|
||||||
30 0.5625
|
|
||||||
31 0.5
|
|
||||||
32 0.5
|
|
||||||
33 0.5625
|
|
||||||
34 0.5625
|
|
||||||
35 0.3125
|
|
||||||
36 0.3125
|
|
||||||
37 0.3125
|
|
||||||
38 0.375
|
|
||||||
39 0.5625
|
|
||||||
40 0.3125
|
|
||||||
41 0.5625
|
|
||||||
42 0.3125
|
|
||||||
43 0.375
|
|
||||||
44 0.5625
|
|
||||||
45 0.5
|
|
||||||
46 0.375
|
|
||||||
47 0.375
|
|
||||||
48 0.3125
|
|
||||||
49 0.375
|
|
||||||
50 0.375
|
|
||||||
51 0.5
|
|
||||||
52 0.5625
|
|
||||||
53 0.375
|
|
||||||
54 0.5625
|
|
||||||
55 0.5625
|
|
||||||
56 0.375
|
|
||||||
57 0.375
|
|
||||||
58 0.375
|
|
||||||
59 0.5
|
|
||||||
60 0.3125
|
|
||||||
61 0.375
|
|
||||||
62 0.375
|
|
||||||
63 0.375
|
|
||||||
64 0.375
|
|
||||||
65 0.375
|
|
||||||
66 0.3125
|
|
||||||
67 0.375
|
|
||||||
68 0.5625
|
|
||||||
69 0.5625
|
|
||||||
70 0.5625
|
|
||||||
71 0.5
|
|
||||||
72 0.5625
|
|
||||||
73 0.375
|
|
||||||
74 0.375
|
|
||||||
75 0.375
|
|
||||||
76 0.375
|
|
||||||
77 0.375
|
|
||||||
78 0.5
|
|
||||||
79 0.375
|
|
||||||
80 0.375
|
|
||||||
81 0.5
|
|
||||||
82 0.375
|
|
||||||
83 0.375
|
|
||||||
84 0.375
|
|
||||||
85 0.375
|
|
||||||
86 0.3125
|
|
||||||
87 0.375
|
|
||||||
88 0.375
|
|
||||||
89 0.5
|
|
||||||
90 0.375
|
|
||||||
91 0.4375
|
|
||||||
92 0.3125
|
|
||||||
93 0.3125
|
|
||||||
94 0.375
|
|
||||||
95 0.375
|
|
||||||
96 0.375
|
|
||||||
97 0.375
|
|
||||||
98 0.3125
|
|
||||||
99 0.4375
|
|
||||||
100 0.375
|
|
||||||
101 0.375
|
|
||||||
102 0.375
|
|
||||||
103 0.3125
|
|
||||||
104 0.5625
|
|
||||||
105 0.5
|
|
||||||
106 0.5625
|
|
||||||
107 0.5625
|
|
||||||
108 0.5
|
|
||||||
109 0.3125
|
|
||||||
110 0.5625
|
|
||||||
111 0.5625
|
|
||||||
112 0.5
|
|
||||||
113 0.3125
|
|
||||||
114 0.5
|
|
||||||
115 0.3125
|
|
||||||
116 0.375
|
|
||||||
117 0.3125
|
|
||||||
118 0.3125
|
|
||||||
119 0.3125
|
|
||||||
120 0.3125
|
|
||||||
121 0.375
|
|
||||||
122 0.375
|
|
||||||
123 0.375
|
|
||||||
124 0.375
|
|
||||||
125 0.3125
|
|
||||||
126 0.375
|
|
||||||
127 0.375
|
|
||||||
128 0.375
|
|
||||||
129 0.375
|
|
||||||
130 0.5625
|
|
||||||
131 0.375
|
|
||||||
132 0.5
|
|
||||||
133 0.3125
|
|
||||||
134 0.3125
|
|
||||||
135 0.3125
|
|
||||||
136 0.375
|
|
||||||
137 0.5
|
|
||||||
138 0.3125
|
|
||||||
139 0.375
|
|
||||||
140 0.3125
|
|
||||||
141 0.3125
|
|
||||||
142 0.3125
|
|
||||||
143 0.5625
|
|
||||||
144 0.3125
|
|
||||||
145 0.375
|
|
||||||
146 0.5
|
|
||||||
147 0.5
|
|
||||||
148 0.375
|
|
||||||
149 0.4375
|
|
||||||
150 0.5
|
|
||||||
151 0.3125
|
|
||||||
152 0.375
|
|
||||||
153 0.375
|
|
||||||
154 0.375
|
|
||||||
155 0.3125
|
|
||||||
156 0.375
|
|
||||||
157 0.4375
|
|
||||||
158 0.4375
|
|
||||||
159 0.375
|
|
||||||
160 0.375
|
|
||||||
161 0.3125
|
|
||||||
162 0.375
|
|
||||||
163 0.375
|
|
||||||
164 0.375
|
|
||||||
165 0.3125
|
|
||||||
166 0.3125
|
|
||||||
167 0.3125
|
|
||||||
168 0.375
|
|
||||||
169 0.3125
|
|
||||||
170 0.3125
|
|
||||||
171 0.3125
|
|
||||||
172 0.375
|
|
||||||
173 0.3125
|
|
||||||
174 0.3125
|
|
||||||
175 0.5
|
|
||||||
176 0.3125
|
|
||||||
177 0.375
|
|
||||||
178 0.375
|
|
||||||
179 0.3125
|
|
||||||
180 0.3125
|
|
||||||
181 0.3125
|
|
||||||
182 0.3125
|
|
||||||
183 0.5625
|
|
||||||
184 0.5625
|
|
||||||
185 0.3125
|
|
||||||
186 0.5
|
|
||||||
187 0.5
|
|
||||||
188 0.5625
|
|
||||||
189 0.5
|
|
||||||
190 0.5625
|
|
||||||
191 0.5625
|
|
||||||
192 0.5625
|
|
||||||
193 0.5
|
|
||||||
194 0.5
|
|
||||||
195 0.5625
|
|
||||||
196 0.5625
|
|
||||||
197 0.5625
|
|
||||||
198 0.5625
|
|
||||||
199 0.5
|
|
||||||
200 0.5625
|
|
||||||
201 0.5625
|
|
||||||
202 0.375
|
|
||||||
203 0.375
|
|
||||||
204 0.375
|
|
||||||
205 0.375
|
|
||||||
206 0.375
|
|
||||||
207 0.5
|
|
||||||
208 0.5
|
|
||||||
209 0.5625
|
|
||||||
210 0.5625
|
|
||||||
211 0.5625
|
|
||||||
212 0.3125
|
|
||||||
213 0.5
|
|
||||||
214 0.5
|
|
||||||
215 0.5625
|
|
||||||
216 0.5
|
|
||||||
217 0.5
|
|
||||||
218 0.5
|
|
||||||
219 0.5625
|
|
||||||
220 0.5
|
|
||||||
221 0.4375
|
|
||||||
222 0.5
|
|
||||||
223 0.5
|
|
||||||
224 0.4375
|
|
||||||
225 0.5
|
|
||||||
226 0.4375
|
|
||||||
227 0.5
|
|
||||||
228 0.5
|
|
||||||
229 0.375
|
|
||||||
230 0.375
|
|
||||||
231 0.3125
|
|
||||||
232 0.375
|
|
||||||
233 0.375
|
|
||||||
234 0.375
|
|
||||||
235 0.5625
|
|
||||||
236 0.5625
|
|
||||||
237 0.5625
|
|
||||||
238 0.5625
|
|
||||||
239 0.5625
|
|
||||||
240 0.5
|
|
||||||
241 0.5
|
|
||||||
242 0.5
|
|
||||||
243 0.5625
|
|
||||||
244 0.5625
|
|
||||||
245 0.375
|
|
||||||
246 0.375
|
|
||||||
247 0.375
|
|
||||||
248 0.3125
|
|
||||||
249 0.375
|
|
||||||
The average accuracy is: 0.42675
|
|
||||||
The best accuracy is: 0.5625
|
|
||||||
52
requirements.txt
Normal file
52
requirements.txt
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
Bottleneck==1.4.2
|
||||||
|
brotlicffi==1.2.0.0
|
||||||
|
certifi==2026.5.20
|
||||||
|
cffi==2.0.0
|
||||||
|
charset-normalizer==3.4.4
|
||||||
|
contourpy==1.3.2
|
||||||
|
cycler==0.12.1
|
||||||
|
einops==0.8.2
|
||||||
|
filelock==3.20.3
|
||||||
|
fonttools==4.63.0
|
||||||
|
gmpy2==2.2.2
|
||||||
|
idna==3.11
|
||||||
|
Jinja2==3.1.6
|
||||||
|
joblib==1.5.3
|
||||||
|
kiwisolver==1.5.0
|
||||||
|
MarkupSafe==3.0.2
|
||||||
|
matplotlib==3.10.9
|
||||||
|
mkl_fft==1.3.11
|
||||||
|
mkl_random==1.2.8
|
||||||
|
mkl-service==2.5.2
|
||||||
|
mpmath==1.3.0
|
||||||
|
networkx==3.4.2
|
||||||
|
Nuitka==4.1.1
|
||||||
|
numexpr==2.14.1
|
||||||
|
numpy==1.24.3
|
||||||
|
packaging==26.0
|
||||||
|
pandas==2.3.3
|
||||||
|
pillow==12.2.0
|
||||||
|
pip==26.0.1
|
||||||
|
pycparser==3.0
|
||||||
|
pyparsing==3.3.2
|
||||||
|
pyserial==3.5
|
||||||
|
PySocks==1.7.1
|
||||||
|
python-dateutil==2.9.0.post0
|
||||||
|
pytz==2026.1.post1
|
||||||
|
pyzmq==27.1.0
|
||||||
|
requests==2.33.1
|
||||||
|
scikit-learn==1.7.1
|
||||||
|
scipy==1.15.3
|
||||||
|
setuptools==82.0.1
|
||||||
|
six==1.17.0
|
||||||
|
sympy==1.14.0
|
||||||
|
threadpoolctl==3.5.0
|
||||||
|
torch==2.0.0
|
||||||
|
torchaudio==2.0.0
|
||||||
|
torchsummary==1.5.1
|
||||||
|
torchvision==0.15.0
|
||||||
|
typing_extensions==4.15.0
|
||||||
|
tzdata==2026.2
|
||||||
|
urllib3==2.7.0
|
||||||
|
wheel==0.46.3
|
||||||
|
win_inet_pton==1.1.0
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
import matplotlib
|
# import matplotlib
|
||||||
matplotlib.use('Agg')
|
# matplotlib.use('Agg')
|
||||||
import argparse
|
# import argparse
|
||||||
import sys
|
# import sys
|
||||||
import time
|
import time
|
||||||
from Decoder import Decoder_main
|
from Decoder import Decoder_main
|
||||||
from PubLibrary.RunOnce import is_program_running
|
from PubLibrary.RunOnce import is_program_running
|
||||||
|
|||||||
422
system_test.py
422
system_test.py
@@ -1,422 +0,0 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
ZMQ 脑电数据测试工具【语法错误修复版】
|
|
||||||
修复点:
|
|
||||||
1. dataclass 可变列表默认值报错
|
|
||||||
2. threading.Thread daemon 参数语法错误
|
|
||||||
适配:Python3.10、全链路 float64、ZMQ DEALER<->ROUTER
|
|
||||||
端口:8099(命令) / 8100(数据)
|
|
||||||
"""
|
|
||||||
import zmq
|
|
||||||
import time
|
|
||||||
import threading
|
|
||||||
import numpy as np
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import json
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Dict, List, Optional, Union, Tuple
|
|
||||||
from matplotlib.animation import FuncAnimation
|
|
||||||
|
|
||||||
# ===================== 1. 配置管理 =====================
|
|
||||||
@dataclass(frozen=True) # 冻结配置类
|
|
||||||
class TestConfig:
|
|
||||||
# 网络配置
|
|
||||||
SERVER_IP: str = "127.0.0.1"
|
|
||||||
CMD_PORT: int = 8099
|
|
||||||
DATA_PORT: int = 8100
|
|
||||||
|
|
||||||
# 硬件与时序
|
|
||||||
SAMPLE_RATE: int = 250
|
|
||||||
FRAME_INTERVAL_MS: int = 20
|
|
||||||
SEND_INTERVAL: float = FRAME_INTERVAL_MS / 1000
|
|
||||||
CHANNEL_NUMS: int = 66
|
|
||||||
FRAME_POINTS: int = 5
|
|
||||||
FILTER_OUT_CHAN: int = 64
|
|
||||||
FILTER_FRAME_POINTS: int = 50
|
|
||||||
|
|
||||||
# 数据类型 & 字节数 (float64 8字节)
|
|
||||||
DATA_DTYPE: np.dtype = np.float64
|
|
||||||
RAW_FRAME_BYTES: int = CHANNEL_NUMS * FRAME_POINTS * 8 # 66*5*8 = 2640
|
|
||||||
FILTER_FRAME_BYTES: int = FILTER_OUT_CHAN * FILTER_FRAME_POINTS * 8 # 25600
|
|
||||||
|
|
||||||
# 事件通道索引
|
|
||||||
EVENT_CHANNEL_IDX: int = -2
|
|
||||||
|
|
||||||
# 列表类型 使用 default_factory 规避可变默认值报错
|
|
||||||
EVENT_TAGS: List[int] = field(default_factory=lambda: [1, 2, 99])
|
|
||||||
SIM_SIGNAL_FREQ: List[float] = field(default_factory=lambda: [8.0, 9.0])
|
|
||||||
|
|
||||||
# 仿真噪声
|
|
||||||
NOISE_STD: float = 0.25
|
|
||||||
|
|
||||||
# 可视化配置
|
|
||||||
PLOT_TARGET_CHAN: int = 0
|
|
||||||
PLOT_WINDOW_LEN: int = 400
|
|
||||||
PLOT_REFRESH_INTERVAL: int = 50
|
|
||||||
|
|
||||||
# 日志限流
|
|
||||||
FRAME_ERR_INTERVAL: float = 3.0
|
|
||||||
|
|
||||||
# ZMQ 配置
|
|
||||||
SEND_RETRY_MAX: int = 3
|
|
||||||
SEND_RETRY_SLEEP: float = 0.01
|
|
||||||
ZMQ_HWM: int = 1000
|
|
||||||
|
|
||||||
# 初始化全局配置
|
|
||||||
CONFIG = TestConfig()
|
|
||||||
|
|
||||||
# ===================== 2. 全局状态管理 =====================
|
|
||||||
class GlobalState:
|
|
||||||
def __init__(self):
|
|
||||||
self.run_flag: bool = True
|
|
||||||
self.last_frame_err_time: float = 0.0
|
|
||||||
|
|
||||||
GLOBAL_STATE = GlobalState()
|
|
||||||
|
|
||||||
# ===================== 3. Matplotlib 中文初始化 =====================
|
|
||||||
def init_matplotlib():
|
|
||||||
# Windows 黑体,Linux/Mac 自行替换字体
|
|
||||||
plt.rcParams['font.sans-serif'] = ['SimHei']
|
|
||||||
plt.rcParams['axes.unicode_minus'] = False # 修复负号乱码
|
|
||||||
|
|
||||||
init_matplotlib()
|
|
||||||
|
|
||||||
# ===================== 4. ZMQ DEALER 客户端 =====================
|
|
||||||
class ZmqDealerClient:
|
|
||||||
"""适配 ROUTER 的 DEALER 客户端,高频流式数据专用"""
|
|
||||||
def __init__(self, server_ip: str, port: int):
|
|
||||||
self.ctx: zmq.Context = zmq.Context()
|
|
||||||
self.socket: zmq.Socket = self.ctx.socket(zmq.DEALER)
|
|
||||||
self._configure_socket()
|
|
||||||
self.socket.connect(f"tcp://{server_ip}:{port}")
|
|
||||||
|
|
||||||
def _configure_socket(self):
|
|
||||||
"""套接字参数配置"""
|
|
||||||
self.socket.setsockopt(zmq.RCVHWM, CONFIG.ZMQ_HWM)
|
|
||||||
self.socket.setsockopt(zmq.SNDHWM, CONFIG.ZMQ_HWM)
|
|
||||||
self.socket.setsockopt(zmq.RCVTIMEO, 0)
|
|
||||||
self.socket.setsockopt(zmq.SNDTIMEO, 0)
|
|
||||||
|
|
||||||
def send_json(self, data: Dict) -> bool:
|
|
||||||
"""发送JSON命令,带重试机制"""
|
|
||||||
try:
|
|
||||||
payload = json.dumps(data, ensure_ascii=False).encode("utf-8")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[JSON序列化失败] {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
for _ in range(CONFIG.SEND_RETRY_MAX):
|
|
||||||
try:
|
|
||||||
self.socket.send_multipart([b"", payload])
|
|
||||||
return True
|
|
||||||
except zmq.Again:
|
|
||||||
time.sleep(CONFIG.SEND_RETRY_SLEEP)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[JSON发送异常] {e}")
|
|
||||||
time.sleep(CONFIG.SEND_RETRY_SLEEP)
|
|
||||||
print(f"[JSON发送重试失败]")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def send_bytes(self, data: bytes) -> bool:
|
|
||||||
"""发送二进制脑电数据,带重试"""
|
|
||||||
for _ in range(CONFIG.SEND_RETRY_MAX):
|
|
||||||
try:
|
|
||||||
self.socket.send_multipart([b"", data])
|
|
||||||
return True
|
|
||||||
except zmq.Again:
|
|
||||||
time.sleep(CONFIG.SEND_RETRY_SLEEP)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[二进制发送异常] {e}")
|
|
||||||
time.sleep(CONFIG.SEND_RETRY_SLEEP)
|
|
||||||
print(f"[二进制发送重试失败]")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def recv_json(self) -> Optional[Dict]:
|
|
||||||
"""接收JSON命令响应(标准3帧)"""
|
|
||||||
try:
|
|
||||||
frames = self.socket.recv_multipart()
|
|
||||||
if len(frames) < 3:
|
|
||||||
self._log_frame_err(f"帧数异常: {len(frames)}")
|
|
||||||
return None
|
|
||||||
payload = frames[2].decode("utf-8")
|
|
||||||
return json.loads(payload)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
self._log_frame_err("JSON解析失败")
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
self._log_frame_err(f"接收异常: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def recv_bytes(self) -> Optional[bytes]:
|
|
||||||
"""接收滤波数据,兼容3/4帧格式"""
|
|
||||||
try:
|
|
||||||
frames = self.socket.recv_multipart()
|
|
||||||
frame_len = len(frames)
|
|
||||||
if frame_len == 3:
|
|
||||||
payload = frames[2]
|
|
||||||
elif frame_len == 4:
|
|
||||||
payload = frames[3]
|
|
||||||
else:
|
|
||||||
self._log_frame_err(f"帧数异常: {frame_len}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
if len(payload) != CONFIG.FILTER_FRAME_BYTES:
|
|
||||||
self._log_frame_err(f"字节不匹配: 期望{CONFIG.FILTER_FRAME_BYTES}, 实际{len(payload)}")
|
|
||||||
return None
|
|
||||||
return payload
|
|
||||||
except Exception as e:
|
|
||||||
self._log_frame_err(f"数据接收异常: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _log_frame_err(self, msg: str):
|
|
||||||
"""日志限流,防止刷屏"""
|
|
||||||
now = time.time()
|
|
||||||
if now - GLOBAL_STATE.last_frame_err_time > CONFIG.FRAME_ERR_INTERVAL:
|
|
||||||
print(f"[帧异常] {msg}")
|
|
||||||
GLOBAL_STATE.last_frame_err_time = now
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
"""优雅释放ZMQ资源"""
|
|
||||||
try:
|
|
||||||
self.socket.close(linger=0)
|
|
||||||
self.ctx.term()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[资源释放异常] {e}")
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
||||||
self.close()
|
|
||||||
|
|
||||||
# ===================== 5. 仿真脑电数据生成 =====================
|
|
||||||
def generate_raw_eeg_frame(add_event: bool = False) -> np.ndarray:
|
|
||||||
"""生成单帧float64仿真脑电数据"""
|
|
||||||
t = np.linspace(
|
|
||||||
0, CONFIG.FRAME_POINTS / CONFIG.SAMPLE_RATE,
|
|
||||||
CONFIG.FRAME_POINTS, endpoint=False
|
|
||||||
)
|
|
||||||
eeg_frame = np.zeros(
|
|
||||||
(CONFIG.CHANNEL_NUMS, CONFIG.FRAME_POINTS),
|
|
||||||
dtype=CONFIG.DATA_DTYPE
|
|
||||||
)
|
|
||||||
|
|
||||||
# 模拟脑电信号 + 高斯噪声
|
|
||||||
for freq in CONFIG.SIM_SIGNAL_FREQ:
|
|
||||||
eeg_frame[:CONFIG.FILTER_OUT_CHAN] += np.sin(2 * np.pi * freq * t)
|
|
||||||
eeg_frame[:CONFIG.FILTER_OUT_CHAN] += np.random.normal(
|
|
||||||
0, CONFIG.NOISE_STD,
|
|
||||||
size=(CONFIG.FILTER_OUT_CHAN, CONFIG.FRAME_POINTS)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 事件通道处理
|
|
||||||
eeg_frame[CONFIG.EVENT_CHANNEL_IDX] = 0.0
|
|
||||||
if add_event:
|
|
||||||
event_pos = np.random.randint(0, CONFIG.FRAME_POINTS)
|
|
||||||
eeg_frame[CONFIG.EVENT_CHANNEL_IDX, event_pos] = np.random.choice(CONFIG.EVENT_TAGS)
|
|
||||||
|
|
||||||
# 预留通道置0
|
|
||||||
eeg_frame[-1] = 0.0
|
|
||||||
return eeg_frame
|
|
||||||
|
|
||||||
# ===================== 6. 后台工作线程 =====================
|
|
||||||
def start_cmd_response_thread(cmd_client: ZmqDealerClient):
|
|
||||||
"""命令响应接收线程"""
|
|
||||||
print("[线程-命令接收] 已启动")
|
|
||||||
while GLOBAL_STATE.run_flag:
|
|
||||||
msg = cmd_client.recv_json()
|
|
||||||
if msg:
|
|
||||||
print(f"\n【命令响应】{json.dumps(msg, ensure_ascii=False, indent=2)}")
|
|
||||||
time.sleep(0.01)
|
|
||||||
print("[线程-命令接收] 已退出")
|
|
||||||
|
|
||||||
def start_raw_eeg_send_thread(data_client: ZmqDealerClient):
|
|
||||||
"""原始脑电发送线程(20ms/帧)"""
|
|
||||||
print(f"[线程-原始数据发送] 20ms/帧 | 单帧{CONFIG.RAW_FRAME_BYTES}字节 | float64")
|
|
||||||
frame_count = 0
|
|
||||||
while GLOBAL_STATE.run_flag:
|
|
||||||
insert_event = (frame_count % 20 == 0)
|
|
||||||
eeg_frame = generate_raw_eeg_frame(add_event=insert_event)
|
|
||||||
frame_bytes = eeg_frame.tobytes()
|
|
||||||
|
|
||||||
# 字节校验
|
|
||||||
if len(frame_bytes) != CONFIG.RAW_FRAME_BYTES:
|
|
||||||
print(f"[字节警告] 期望{CONFIG.RAW_FRAME_BYTES}, 实际{len(frame_bytes)}")
|
|
||||||
time.sleep(CONFIG.SEND_INTERVAL)
|
|
||||||
frame_count += 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
data_client.send_bytes(frame_bytes)
|
|
||||||
frame_count += 1
|
|
||||||
time.sleep(CONFIG.SEND_INTERVAL)
|
|
||||||
print("[线程-原始数据发送] 已退出")
|
|
||||||
|
|
||||||
def start_filter_data_recv_thread(data_client: ZmqDealerClient, plot_queue: List[np.ndarray]):
|
|
||||||
"""滤波数据接收线程"""
|
|
||||||
print(f"[线程-滤波数据接收] 单包{CONFIG.FILTER_FRAME_BYTES}字节 | float64")
|
|
||||||
while GLOBAL_STATE.run_flag:
|
|
||||||
raw_bytes = data_client.recv_bytes()
|
|
||||||
if not raw_bytes:
|
|
||||||
time.sleep(0.01)
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
filter_arr = np.frombuffer(raw_bytes, dtype=CONFIG.DATA_DTYPE)
|
|
||||||
filter_arr = filter_arr.reshape(CONFIG.FILTER_FRAME_POINTS, CONFIG.FILTER_OUT_CHAN)
|
|
||||||
plot_queue.append(filter_arr[:, CONFIG.PLOT_TARGET_CHAN])
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[滤波数据解析异常] {e}")
|
|
||||||
continue
|
|
||||||
print("[线程-滤波数据接收] 已退出")
|
|
||||||
|
|
||||||
# ===================== 7. 实时波形可视化 =====================
|
|
||||||
def start_wave_visualization(plot_queue: List[np.ndarray]):
|
|
||||||
"""启动实时滤波波形绘图"""
|
|
||||||
fig, ax = plt.subplots(figsize=(14, 4))
|
|
||||||
x_axis = np.arange(0, CONFIG.PLOT_WINDOW_LEN)
|
|
||||||
wave_data = np.zeros(CONFIG.PLOT_WINDOW_LEN, dtype=CONFIG.DATA_DTYPE)
|
|
||||||
line, = ax.plot(x_axis, wave_data, color="#2E86AB", linewidth=1.2)
|
|
||||||
|
|
||||||
ax.set_title(
|
|
||||||
f"实时滤波脑电波形 | 通道 {CONFIG.PLOT_TARGET_CHAN} | {CONFIG.SAMPLE_RATE}Hz | float64",
|
|
||||||
fontsize=12
|
|
||||||
)
|
|
||||||
ax.set_ylim(-3.0, 3.0)
|
|
||||||
ax.grid(True, alpha=0.3, linestyle="--")
|
|
||||||
plt.tight_layout()
|
|
||||||
|
|
||||||
def update_plot(_):
|
|
||||||
nonlocal wave_data
|
|
||||||
if plot_queue:
|
|
||||||
new_wave = plot_queue.pop(0)
|
|
||||||
wave_data = np.roll(wave_data, -len(new_wave))
|
|
||||||
wave_data[-len(new_wave)] = new_wave
|
|
||||||
line.set_ydata(wave_data)
|
|
||||||
return (line,)
|
|
||||||
|
|
||||||
ani = FuncAnimation(
|
|
||||||
fig, update_plot,
|
|
||||||
interval=CONFIG.PLOT_REFRESH_INTERVAL,
|
|
||||||
blit=True,
|
|
||||||
cache_frame_data=False
|
|
||||||
)
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
# ===================== 8. 全量业务测试用例 =====================
|
|
||||||
def run_full_test_cases(cmd_client: ZmqDealerClient):
|
|
||||||
"""全覆盖 zmqServer 所有命令:sync/targetFreqs/decoderClass/impedance/train/predict/rest"""
|
|
||||||
print("\n" + "="*60)
|
|
||||||
print("开始执行全量命令测试用例")
|
|
||||||
print("="*60)
|
|
||||||
time.sleep(2)
|
|
||||||
|
|
||||||
# 1. 同步命令
|
|
||||||
print("\n[用例 1] 发送 sync 命令")
|
|
||||||
cmd_client.send_json({"method": "sync", "params": {}})
|
|
||||||
time.sleep(1)
|
|
||||||
|
|
||||||
# 2. 设置目标频率
|
|
||||||
print("\n[用例 2] 发送 targetFreqs = [8.0, 9.0]")
|
|
||||||
cmd_client.send_json({"method": "targetFreqs", "params": [8.0, 9.0]})
|
|
||||||
time.sleep(1)
|
|
||||||
|
|
||||||
# 3. 切换解码器
|
|
||||||
print("\n[用例 3] 切换解码器为 ssmvep")
|
|
||||||
cmd_client.send_json({"method": "decoderClass", "params": "ssmvep"})
|
|
||||||
time.sleep(2)
|
|
||||||
print("\n[用例 3-2] 切换解码器为 mi")
|
|
||||||
cmd_client.send_json({"method": "decoderClass", "params": "mi"})
|
|
||||||
time.sleep(2)
|
|
||||||
|
|
||||||
# 4. 阻抗检测开关
|
|
||||||
print("\n[用例 4] 开启阻抗检测 impedance=1")
|
|
||||||
cmd_client.send_json({"method": "impedance", "params": 1})
|
|
||||||
time.sleep(1)
|
|
||||||
print("\n[用例 4-2] 关闭阻抗检测 impedance=2")
|
|
||||||
cmd_client.send_json({"method": "impedance", "params": 2})
|
|
||||||
time.sleep(1)
|
|
||||||
|
|
||||||
# 5. 训练模式
|
|
||||||
print("\n[用例 5] 启动训练 train,标签=1")
|
|
||||||
cmd_client.send_json({"method": "train", "params": 1})
|
|
||||||
time.sleep(3)
|
|
||||||
|
|
||||||
# # 6. 休息模式
|
|
||||||
# print("\n[用例 6] 切换 rest 休息模式")
|
|
||||||
# cmd_client.send_json({"method": "rest", "params": {}})
|
|
||||||
# time.sleep(1)
|
|
||||||
|
|
||||||
# 7. 启动解码
|
|
||||||
print("\n[用例 7] 启动解码 predict=1")
|
|
||||||
cmd_client.send_json({"method": "predict", "params": 1})
|
|
||||||
time.sleep(4)
|
|
||||||
|
|
||||||
# # 8. 非法命令(异常测试)
|
|
||||||
# print("\n[用例 8] 发送非法命令 test_cmd_illegal")
|
|
||||||
# cmd_client.send_json({"method": "test_cmd_illegal", "params": {}})
|
|
||||||
# time.sleep(1)
|
|
||||||
|
|
||||||
# # 9. 停止解码
|
|
||||||
# print("\n[用例 9] 停止解码 predict=2")
|
|
||||||
# cmd_client.send_json({"method": "predict", "params": 2})
|
|
||||||
# time.sleep(2)
|
|
||||||
|
|
||||||
print("\n" + "="*60)
|
|
||||||
print("所有测试用例执行完毕")
|
|
||||||
print("="*60)
|
|
||||||
|
|
||||||
# ===================== 主程序入口(修复线程语法) =====================
|
|
||||||
if __name__ == "__main__":
|
|
||||||
print("="*60)
|
|
||||||
print("ZMQ 脑电仿真测试工具 启动")
|
|
||||||
print(f"命令端口: {CONFIG.CMD_PORT} | 数据端口: {CONFIG.DATA_PORT}")
|
|
||||||
print(f"原始帧{CONFIG.RAW_FRAME_BYTES}字节 | 滤波帧{CONFIG.FILTER_FRAME_BYTES}字节 | float64")
|
|
||||||
print("="*60)
|
|
||||||
|
|
||||||
try:
|
|
||||||
with ZmqDealerClient(CONFIG.SERVER_IP, CONFIG.CMD_PORT) as cmd_client, \
|
|
||||||
ZmqDealerClient(CONFIG.SERVER_IP, CONFIG.DATA_PORT) as data_client:
|
|
||||||
|
|
||||||
plot_queue = []
|
|
||||||
|
|
||||||
# ========== 重点修复:线程语法,daemon 移出 args ==========
|
|
||||||
# 命令接收线程
|
|
||||||
t_cmd = threading.Thread(
|
|
||||||
target=start_cmd_response_thread,
|
|
||||||
args=(cmd_client,), # 单元素元组保留逗号
|
|
||||||
daemon=True
|
|
||||||
)
|
|
||||||
# 原始数据发送线程
|
|
||||||
t_eeg = threading.Thread(
|
|
||||||
target=start_raw_eeg_send_thread,
|
|
||||||
args=(data_client,),
|
|
||||||
daemon=True
|
|
||||||
)
|
|
||||||
# 滤波数据接收线程
|
|
||||||
t_filter = threading.Thread(
|
|
||||||
target=start_filter_data_recv_thread,
|
|
||||||
args=(data_client, plot_queue),
|
|
||||||
daemon=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# 启动线程
|
|
||||||
t_cmd.start()
|
|
||||||
t_eeg.start()
|
|
||||||
t_filter.start()
|
|
||||||
|
|
||||||
# 执行测试用例
|
|
||||||
run_full_test_cases(cmd_client)
|
|
||||||
|
|
||||||
# 启动可视化(阻塞主线程)
|
|
||||||
print("\n[提示] 波形窗口已启动,关闭窗口 / Ctrl+C 退出程序")
|
|
||||||
start_wave_visualization(plot_queue)
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print("\n\n[用户中断] 接收到 Ctrl+C,准备退出...")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"\n[程序异常] {e}")
|
|
||||||
finally:
|
|
||||||
# 停止所有后台线程
|
|
||||||
GLOBAL_STATE.run_flag = False
|
|
||||||
time.sleep(0.2)
|
|
||||||
print("程序已安全退出")
|
|
||||||
306
upperHost_stimmock/MI_headless.py
Normal file
306
upperHost_stimmock/MI_headless.py
Normal file
@@ -0,0 +1,306 @@
|
|||||||
|
"""
|
||||||
|
MI_headless.py
|
||||||
|
无界面版 MI 运动想象范式通讯流程模拟脚本。
|
||||||
|
复现 MI_main.py 的完整指令序列(train 0/1, rest, predict, saveData),
|
||||||
|
但不依赖 psychopy 也不打开任何窗口/音频,用 time.sleep 替代帧循环等待。
|
||||||
|
|
||||||
|
启动顺序:
|
||||||
|
1. runDecoder.py
|
||||||
|
2. datamock.py
|
||||||
|
3. MI_headless.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
|
import zmq
|
||||||
|
import numpy as np
|
||||||
|
import ast
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
from PubLibrary.InifileHelper import IniRead
|
||||||
|
|
||||||
|
personname = 'demo'
|
||||||
|
session = '01'
|
||||||
|
|
||||||
|
DATAMOCK_LABEL_ADDR = 'tcp://127.0.0.1:8101' # datamock 标签命令地址
|
||||||
|
|
||||||
|
|
||||||
|
# ========== ZMQ 结果接收服务 ==========
|
||||||
|
class ZmqResultServer(threading.Thread):
|
||||||
|
def __init__(self, port=8088):
|
||||||
|
threading.Thread.__init__(self)
|
||||||
|
self.port = port
|
||||||
|
self.running = True
|
||||||
|
self.energy = 0
|
||||||
|
self.paradigm = 0 # 0=个体校准, 1=康复训练, 2=等待模型训练
|
||||||
|
self.ChoosenNum = -1
|
||||||
|
self.context = zmq.Context()
|
||||||
|
self.socket = self.context.socket(zmq.ROUTER)
|
||||||
|
self.socket.bind(f"tcp://0.0.0.0:{self.port}")
|
||||||
|
self.daemon = True
|
||||||
|
self.trial_idx = 0
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
print(f"[Server] UpperHost_Server listening on {self.port}")
|
||||||
|
while self.running:
|
||||||
|
try:
|
||||||
|
frames = self.socket.recv_multipart(zmq.NOBLOCK)
|
||||||
|
if len(frames) < 3:
|
||||||
|
continue
|
||||||
|
message = json.loads(frames[2].decode('utf-8'))
|
||||||
|
method = message.get('method')
|
||||||
|
params = message.get('params')
|
||||||
|
if method == 'energy':
|
||||||
|
self.energy = params
|
||||||
|
elif method == 'paradigm':
|
||||||
|
self.paradigm = params
|
||||||
|
print(f"[Server] paradigm -> {params}")
|
||||||
|
elif method == 'result':
|
||||||
|
self.ChoosenNum = params
|
||||||
|
self.trial_idx += 1
|
||||||
|
print(f"[Server] result={self.ChoosenNum} (trial {self.trial_idx})")
|
||||||
|
except zmq.Again:
|
||||||
|
time.sleep(0.005)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[Server] error: {e}")
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
self.running = False
|
||||||
|
self.socket.close()
|
||||||
|
self.context.term()
|
||||||
|
|
||||||
|
|
||||||
|
# ========== ZMQ 命令发送客户端 ==========
|
||||||
|
class ZmqCmdClient:
|
||||||
|
def __init__(self, host, port):
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.context = zmq.Context()
|
||||||
|
self.socket = self.context.socket(zmq.DEALER)
|
||||||
|
# PUSH socket 用于向 datamock.py 发送标签命令
|
||||||
|
self._label_sock = self.context.socket(zmq.PUSH)
|
||||||
|
self._label_sock.connect(DATAMOCK_LABEL_ADDR)
|
||||||
|
print(f"[Client] label PUSH connected to {DATAMOCK_LABEL_ADDR}")
|
||||||
|
|
||||||
|
def connect(self):
|
||||||
|
self.socket.connect(f"tcp://{self.host}:{self.port}")
|
||||||
|
print(f"[Client] connected to {self.host}:{self.port}")
|
||||||
|
|
||||||
|
def start_recv_thread(self, result_server):
|
||||||
|
"""启动后台线程,持续接收 decoder 通过 8099 ROUTER 回发的消息,并更新 result_server 的状态"""
|
||||||
|
self._result_server = result_server
|
||||||
|
self._stop_recv = threading.Event()
|
||||||
|
|
||||||
|
def _recv_loop():
|
||||||
|
while not self._stop_recv.is_set():
|
||||||
|
try:
|
||||||
|
frames = self.socket.recv_multipart(zmq.NOBLOCK)
|
||||||
|
# DEALER 收到的格式: [b'', json_bytes]
|
||||||
|
data_bytes = frames[-1]
|
||||||
|
message = json.loads(data_bytes.decode('utf-8'))
|
||||||
|
method = message.get('method')
|
||||||
|
params = message.get('params')
|
||||||
|
ts = datetime.now().strftime('%H:%M:%S.%f')[:-3]
|
||||||
|
print(f"[{ts}] [CmdClient] recv: {method}={params}")
|
||||||
|
if method == 'paradigm':
|
||||||
|
self._result_server.paradigm = params
|
||||||
|
print(f"[{ts}] [CmdClient] paradigm updated -> {params}")
|
||||||
|
elif method == 'result':
|
||||||
|
self._result_server.ChoosenNum = params
|
||||||
|
self._result_server.trial_idx += 1
|
||||||
|
print(f"[{ts}] [CmdClient] result={params} (trial {self._result_server.trial_idx})")
|
||||||
|
elif method == 'energy':
|
||||||
|
self._result_server.energy = params
|
||||||
|
except zmq.Again:
|
||||||
|
time.sleep(0.005)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[CmdClient recv] error: {e}")
|
||||||
|
time.sleep(0.01)
|
||||||
|
|
||||||
|
self._recv_thread = threading.Thread(target=_recv_loop, daemon=True)
|
||||||
|
self._recv_thread.start()
|
||||||
|
print(f"[Client] 后台接收线程已启动(监听 decoder 8099 回发消息)")
|
||||||
|
|
||||||
|
def stop_recv_thread(self):
|
||||||
|
if hasattr(self, '_stop_recv'):
|
||||||
|
self._stop_recv.set()
|
||||||
|
|
||||||
|
def _send_label(self, label_value):
|
||||||
|
"""向 datamock.py 发送标签命令"""
|
||||||
|
try:
|
||||||
|
self._label_sock.send_string(str(label_value), zmq.NOBLOCK)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[Client] label send error: {e}")
|
||||||
|
|
||||||
|
def send_data(self, method, params):
|
||||||
|
msg = {'method': method, 'params': params}
|
||||||
|
try:
|
||||||
|
self.socket.send_multipart([b'', json.dumps(msg).encode('utf-8')])
|
||||||
|
ts = datetime.now().strftime('%H:%M:%S.%f')[:-3]
|
||||||
|
print(f"[{ts}] send_data: {method}={params}")
|
||||||
|
# 根据 train/predict 命令向 datamock 发送标签
|
||||||
|
if method == 'train':
|
||||||
|
if params == 0:
|
||||||
|
self._send_label(1)
|
||||||
|
print(f"[Label] train 0 -> datamock label=1")
|
||||||
|
elif params == 1:
|
||||||
|
self._send_label(2)
|
||||||
|
print(f"[Label] train 1 -> datamock label=2")
|
||||||
|
elif method == 'predict':
|
||||||
|
self._send_label(99)
|
||||||
|
print(f"[Label] predict -> datamock label=99")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[Client] send error: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 主流程 ==========
|
||||||
|
def run_headless():
|
||||||
|
server = ZmqResultServer(port=8088)
|
||||||
|
server.start()
|
||||||
|
|
||||||
|
_dh = str(IniRead('system', 'Decoder_Host'))
|
||||||
|
_dp = int(IniRead('system', 'Decoder_Port'))
|
||||||
|
client = ZmqCmdClient(_dh, _dp)
|
||||||
|
client.connect()
|
||||||
|
client.start_recv_thread(server) # 启动后台接收线程,监听 decoder 8099 回发的 paradigm/result 消息
|
||||||
|
|
||||||
|
time.sleep(1) # 等待连接建立
|
||||||
|
client.send_data('decoderClass', 'mi')
|
||||||
|
time.sleep(4) # 等待 zmqServer 排空启动积压包(datamock 提前连接会积压 ~3s 数据)
|
||||||
|
|
||||||
|
# MI_IntervalEpoch = [0.5, 4.5],trial时长 = 4.5-0.5 = 4.0s
|
||||||
|
_mi_iv = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch')) # [0.5, 4.5]
|
||||||
|
_trial_sec = float(_mi_iv[1] - _mi_iv[0]) # 4.0s
|
||||||
|
_margin = 1.0
|
||||||
|
train_time = max(5.0, _trial_sec + _margin) # 训练刺激时长(与 MI_main.py 保持一致)
|
||||||
|
|
||||||
|
# MI epoch latency = interval_epoch[1] // 5 = (4.5*250)//5 = 225包 × 20ms = 4.5s
|
||||||
|
# train_latency = 225包(MI中 train_latency == latency)
|
||||||
|
# 在 train_time 后需再等 epoch_wait 秒,decoder 才能完成 epoch 采集
|
||||||
|
epoch_wait = _mi_iv[1] / _mi_iv[1] * (_mi_iv[1] * 250 // 5) * 0.02 # = latency * 20ms
|
||||||
|
# 更直接的计算:latency = interval_epoch[1] // 5 = int(4.5*250)//5 = 225,225*0.02 = 4.5s
|
||||||
|
epoch_wait = (int(_mi_iv[1] * 250) // 5) * 0.02 # 4.5s
|
||||||
|
|
||||||
|
# predict epoch wait(与 train 相同,MI中 latency == train_latency)
|
||||||
|
predict_epoch_wait = epoch_wait # 4.5s
|
||||||
|
|
||||||
|
test_time = 7.0 # 预测窗口时长(与 MI_main.py 保持一致)
|
||||||
|
right_rehabilitation = float(IniRead('system', 'Right_rehabilitation'))
|
||||||
|
fault_rehabilitation = float(IniRead('system', 'Fault_rehabilitation'))
|
||||||
|
rest_time = float(IniRead('system', 'Rest_time'))
|
||||||
|
|
||||||
|
num_blocks = int(IniRead('system', 'Num_blocks'))
|
||||||
|
num_trials = int(IniRead('system', 'Num_trials'))
|
||||||
|
|
||||||
|
trained = 0
|
||||||
|
Num_Total = 0
|
||||||
|
Num_Success = 0
|
||||||
|
user_choice = []
|
||||||
|
|
||||||
|
print("=" * 50)
|
||||||
|
print("[Headless] 开始运行 MI 通讯流程(无界面)")
|
||||||
|
print(f" MI_IntervalEpoch={_mi_iv}, trial_sec={_trial_sec:.2f}s")
|
||||||
|
print(f" train_time={train_time:.2f}s, epoch_wait={epoch_wait:.2f}s")
|
||||||
|
print(f" test_time={test_time:.2f}s, predict_epoch_wait={predict_epoch_wait:.2f}s")
|
||||||
|
print(f" num_blocks={num_blocks}, num_trials={num_trials}")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
# -------- 个体校准阶段 --------
|
||||||
|
print("\n[Phase] 个体校准阶段 (paradigm=0)")
|
||||||
|
client.send_data('rest', 0)
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
while server.paradigm == 0:
|
||||||
|
# 左侧 MI 刺激(train 0,label=1)
|
||||||
|
print(f"\n[Train] 左侧 MI 刺激 (train 0) trained={trained}")
|
||||||
|
client.send_data('rest', 0)
|
||||||
|
time.sleep(0.5) # ding 提示后等待
|
||||||
|
|
||||||
|
client.send_data('train', 0)
|
||||||
|
time.sleep(train_time + 0.2) # 等待刺激时间 + epoch 完成时间
|
||||||
|
|
||||||
|
trained += 1
|
||||||
|
client.send_data('rest', 0)
|
||||||
|
time.sleep(1.0) # 类间休息
|
||||||
|
|
||||||
|
# 空闲态样本采集(train 1,label=2)
|
||||||
|
print(f"\n[Train] 空闲态采集 (train 1) trained={trained}")
|
||||||
|
client.send_data('train', 1)
|
||||||
|
time.sleep(train_time + 0.2) # 等待刺激时间 + epoch 完成时间
|
||||||
|
|
||||||
|
trained += 1
|
||||||
|
client.send_data('rest', 0)
|
||||||
|
time.sleep(1.0) # 类间休息
|
||||||
|
|
||||||
|
# 个体校准阶段结束
|
||||||
|
print("\n[Phase] 个体校准结束,等待模型训练 (paradigm=2) ...")
|
||||||
|
trained = 0
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
# 等待模型训练完成 (paradigm=2 -> paradigm=1)
|
||||||
|
while server.paradigm == 2:
|
||||||
|
print("[Phase] 等待模型训练完成 ...")
|
||||||
|
time.sleep(0.5)
|
||||||
|
|
||||||
|
# -------- 康复训练阶段 --------
|
||||||
|
while server.paradigm == 1:
|
||||||
|
print("\n[Phase] 康复训练阶段 (paradigm=1)")
|
||||||
|
for block_idx in range(num_blocks):
|
||||||
|
print(f"\n [Block {block_idx+1}/{num_blocks}]")
|
||||||
|
time.sleep(10) # 每轮开始前等待
|
||||||
|
|
||||||
|
for trial_idx in range(num_trials):
|
||||||
|
print(f" [Trial {trial_idx+1}/{num_trials}]")
|
||||||
|
|
||||||
|
time.sleep(0.5) # ding 提示
|
||||||
|
server.ChoosenNum = -1
|
||||||
|
|
||||||
|
# 开始预测
|
||||||
|
# MI predict epoch latency = 225包 × 20ms = 4.5s,需额外等待 epoch 完成
|
||||||
|
client.send_data('predict', 1)
|
||||||
|
t_start = time.perf_counter()
|
||||||
|
while time.perf_counter() - t_start < test_time + predict_epoch_wait:
|
||||||
|
if server.ChoosenNum >= 0:
|
||||||
|
Num_Total += 1
|
||||||
|
user_choice.append(server.ChoosenNum)
|
||||||
|
if server.ChoosenNum == 0:
|
||||||
|
Num_Success += 1
|
||||||
|
rest_time = right_rehabilitation
|
||||||
|
elif server.ChoosenNum == 1:
|
||||||
|
rest_time = fault_rehabilitation
|
||||||
|
break
|
||||||
|
time.sleep(0.02)
|
||||||
|
|
||||||
|
trained += 1
|
||||||
|
client.send_data('rest', 0)
|
||||||
|
time.sleep(0.5)
|
||||||
|
time.sleep(rest_time)
|
||||||
|
server.ChoosenNum = -1
|
||||||
|
|
||||||
|
# 训练结束
|
||||||
|
print("\n[Phase] 康复训练结束")
|
||||||
|
break # 退出康复训练循环
|
||||||
|
|
||||||
|
# 统计结果
|
||||||
|
overall_accuracy = Num_Success / Num_Total if Num_Total > 0 else 0
|
||||||
|
print(f"\n[Result] Overall={overall_accuracy:.3f} ({Num_Success}/{Num_Total})")
|
||||||
|
print(f"[Result] user_choice={user_choice}")
|
||||||
|
break # 完成一个完整流程后退出
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n[Headless] 用户中断")
|
||||||
|
finally:
|
||||||
|
client.send_data('predict', 2) # 关闭系统
|
||||||
|
client.send_data('saveData', 0)
|
||||||
|
server.stop()
|
||||||
|
print("[Headless] 已发送关闭指令,退出。")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
run_headless()
|
||||||
301
upperHost_stimmock/ssmvep_headless.py
Normal file
301
upperHost_stimmock/ssmvep_headless.py
Normal file
@@ -0,0 +1,301 @@
|
|||||||
|
"""
|
||||||
|
ssmvep_headless.py
|
||||||
|
无界面版 SSMVEP 范式通讯流程模拟脚本。
|
||||||
|
复现 ssmvep_main.py 的完整指令序列(train 0/1/2, rest, predict, saveData),
|
||||||
|
但不依赖 psychopy 也不打开任何窗口/音频,用 time.sleep 替代帧循环等待。
|
||||||
|
|
||||||
|
启动顺序:
|
||||||
|
1. runDecoder.py
|
||||||
|
2. datamock.py
|
||||||
|
3. ssmvep_headless.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
|
import zmq
|
||||||
|
import numpy as np
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
from PubLibrary.InifileHelper import IniRead
|
||||||
|
|
||||||
|
personname = 'demo'
|
||||||
|
session = '01'
|
||||||
|
|
||||||
|
DATAMOCK_LABEL_ADDR = 'tcp://127.0.0.1:8101' # datamock 标签命令地址
|
||||||
|
|
||||||
|
|
||||||
|
# ========== ZMQ 结果接收服务 ==========
|
||||||
|
class ZmqResultServer(threading.Thread):
|
||||||
|
def __init__(self, port=8088):
|
||||||
|
threading.Thread.__init__(self)
|
||||||
|
self.port = port
|
||||||
|
self.running = True
|
||||||
|
self.energy = 0
|
||||||
|
self.paradigm = 0 # 0=个体校准, 1=康复训练, 2=等待模型训练
|
||||||
|
self.ChoosenNum = -1
|
||||||
|
self.context = zmq.Context()
|
||||||
|
self.socket = self.context.socket(zmq.ROUTER)
|
||||||
|
self.socket.bind(f"tcp://0.0.0.0:{self.port}")
|
||||||
|
self.daemon = True
|
||||||
|
self.trial_idx = 0
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
print(f"[Server] UpperHost_Server listening on {self.port}")
|
||||||
|
while self.running:
|
||||||
|
try:
|
||||||
|
frames = self.socket.recv_multipart(zmq.NOBLOCK)
|
||||||
|
if len(frames) < 3:
|
||||||
|
continue
|
||||||
|
message = json.loads(frames[2].decode('utf-8'))
|
||||||
|
method = message.get('method')
|
||||||
|
params = message.get('params')
|
||||||
|
if method == 'energy':
|
||||||
|
self.energy = params
|
||||||
|
elif method == 'paradigm':
|
||||||
|
self.paradigm = params
|
||||||
|
print(f"[Server] paradigm -> {params}")
|
||||||
|
elif method == 'result':
|
||||||
|
self.ChoosenNum = params
|
||||||
|
self.trial_idx += 1
|
||||||
|
print(f"[Server] result={self.ChoosenNum} (trial {self.trial_idx})")
|
||||||
|
except zmq.Again:
|
||||||
|
time.sleep(0.005)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[Server] error: {e}")
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
self.running = False
|
||||||
|
self.socket.close()
|
||||||
|
self.context.term()
|
||||||
|
|
||||||
|
|
||||||
|
# ========== ZMQ 命令发送客户端 ==========
|
||||||
|
class ZmqCmdClient:
|
||||||
|
def __init__(self, host, port):
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.context = zmq.Context()
|
||||||
|
self.socket = self.context.socket(zmq.DEALER)
|
||||||
|
# PUSH socket 用于向 datamock.py 发送标签命令
|
||||||
|
self._label_sock = self.context.socket(zmq.PUSH)
|
||||||
|
self._label_sock.connect(DATAMOCK_LABEL_ADDR)
|
||||||
|
print(f"[Client] label PUSH connected to {DATAMOCK_LABEL_ADDR}")
|
||||||
|
|
||||||
|
def connect(self):
|
||||||
|
self.socket.connect(f"tcp://{self.host}:{self.port}")
|
||||||
|
print(f"[Client] connected to {self.host}:{self.port}")
|
||||||
|
|
||||||
|
def start_recv_thread(self, result_server):
|
||||||
|
"""启动后台线程,持续接收 decoder 通过 8099 ROUTER 回发的消息,并更新 result_server 的状态"""
|
||||||
|
self._result_server = result_server
|
||||||
|
self._stop_recv = threading.Event()
|
||||||
|
|
||||||
|
def _recv_loop():
|
||||||
|
while not self._stop_recv.is_set():
|
||||||
|
try:
|
||||||
|
frames = self.socket.recv_multipart(zmq.NOBLOCK)
|
||||||
|
# DEALER 收到的格式: [b'', json_bytes]
|
||||||
|
data_bytes = frames[-1]
|
||||||
|
message = json.loads(data_bytes.decode('utf-8'))
|
||||||
|
method = message.get('method')
|
||||||
|
params = message.get('params')
|
||||||
|
ts = datetime.now().strftime('%H:%M:%S.%f')[:-3]
|
||||||
|
print(f"[{ts}] [CmdClient] recv: {method}={params}")
|
||||||
|
if method == 'paradigm':
|
||||||
|
self._result_server.paradigm = params
|
||||||
|
print(f"[{ts}] [CmdClient] paradigm updated -> {params}")
|
||||||
|
elif method == 'result':
|
||||||
|
self._result_server.ChoosenNum = params
|
||||||
|
self._result_server.trial_idx += 1
|
||||||
|
print(f"[{ts}] [CmdClient] result={params} (trial {self._result_server.trial_idx})")
|
||||||
|
elif method == 'energy':
|
||||||
|
self._result_server.energy = params
|
||||||
|
except zmq.Again:
|
||||||
|
time.sleep(0.005)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[CmdClient recv] error: {e}")
|
||||||
|
time.sleep(0.01)
|
||||||
|
|
||||||
|
self._recv_thread = threading.Thread(target=_recv_loop, daemon=True)
|
||||||
|
self._recv_thread.start()
|
||||||
|
print(f"[Client] 后台接收线程已启动(监听 decoder 8099 回发消息)")
|
||||||
|
|
||||||
|
def stop_recv_thread(self):
|
||||||
|
if hasattr(self, '_stop_recv'):
|
||||||
|
self._stop_recv.set()
|
||||||
|
|
||||||
|
def _send_label(self, label_value):
|
||||||
|
"""向 datamock.py 发送标签命令"""
|
||||||
|
try:
|
||||||
|
self._label_sock.send_string(str(label_value), zmq.NOBLOCK)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[Client] label send error: {e}")
|
||||||
|
|
||||||
|
def send_data(self, method, params):
|
||||||
|
msg = {'method': method, 'params': params}
|
||||||
|
try:
|
||||||
|
self.socket.send_multipart([b'', json.dumps(msg).encode('utf-8')])
|
||||||
|
ts = datetime.now().strftime('%H:%M:%S.%f')[:-3]
|
||||||
|
print(f"[{ts}] send_data: {method}={params}")
|
||||||
|
# 根据 train/predict 命令向 datamock 发送标签
|
||||||
|
if method == 'train':
|
||||||
|
if params == 0:
|
||||||
|
self._send_label(1)
|
||||||
|
print(f"[Label] train 0 -> datamock label=1")
|
||||||
|
elif params == 1:
|
||||||
|
self._send_label(2)
|
||||||
|
print(f"[Label] train 1 -> datamock label=2")
|
||||||
|
elif method == 'predict':
|
||||||
|
self._send_label(99)
|
||||||
|
print(f"[Label] predict -> datamock label=99")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[Client] send error: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 主流程 ==========
|
||||||
|
def run_headless():
|
||||||
|
server = ZmqResultServer(port=8088)
|
||||||
|
server.start()
|
||||||
|
|
||||||
|
_dh = str(IniRead('system', 'Decoder_Host'))
|
||||||
|
_dp = int(IniRead('system', 'Decoder_Port'))
|
||||||
|
client = ZmqCmdClient(_dh, _dp)
|
||||||
|
client.connect()
|
||||||
|
client.start_recv_thread(server) # 启动后台接收线程,监听 decoder 8099 回发的 paradigm/result 消息
|
||||||
|
|
||||||
|
time.sleep(1) # 等待连接建立
|
||||||
|
client.send_data('decoderClass', 'ssmvep')
|
||||||
|
|
||||||
|
train_time = 2.5 # 每轮训练刺激时长 (s)
|
||||||
|
test_time = 2.5 # 每轮测试刺激时长 (s)
|
||||||
|
right_rehabilitation = float(IniRead('system', 'Right_rehabilitation'))
|
||||||
|
fault_rehabilitation = float(IniRead('system', 'Fault_rehabilitation'))
|
||||||
|
rest_time = float(IniRead('system', 'Rest_time'))
|
||||||
|
|
||||||
|
num_blocks = int(IniRead('system', 'Num_blocks'))
|
||||||
|
num_trials = int(IniRead('system', 'Num_trials'))
|
||||||
|
|
||||||
|
position = [0, 1]
|
||||||
|
truePos_seq = position * int(num_trials / len(position))
|
||||||
|
truePos_seq = np.random.permutation(truePos_seq).tolist()
|
||||||
|
user_choice = []
|
||||||
|
|
||||||
|
os.makedirs('EEGFiles', exist_ok=True)
|
||||||
|
seq_file_path = f'EEGFiles/pos_seq_{personname}{session}_{datetime.now().strftime("%Y-%m-%d-%H-%M-%S")}.json'
|
||||||
|
seq_info = {
|
||||||
|
'position': position,
|
||||||
|
'sequence': truePos_seq,
|
||||||
|
'start_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
}
|
||||||
|
with open(seq_file_path, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(seq_info, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
trained = 0
|
||||||
|
Num_Total = 0
|
||||||
|
Num_Success = 0
|
||||||
|
|
||||||
|
print("=" * 50)
|
||||||
|
print("[Headless] 开始运行 SSMVEP 通讯流程(无界面)")
|
||||||
|
print(f" num_blocks={num_blocks}, num_trials={num_trials}")
|
||||||
|
print(f" train_time={train_time}s, test_time={test_time}s")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
# -------- 个体校准阶段 --------
|
||||||
|
print("\n[Phase] 个体校准阶段 (paradigm=0)")
|
||||||
|
client.send_data('rest', 0)
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
# epoch完成需要的额外等待时间:train_latency=120包×20ms=2.4s
|
||||||
|
# 在train_time后需再等epoch_wait秒,decoder才能完成epoch采集并取出数据
|
||||||
|
epoch_wait = 2.4 # 秒,与train_latency对应
|
||||||
|
|
||||||
|
while server.paradigm == 0:
|
||||||
|
# 左腿刺激
|
||||||
|
print(f"\n[Train] 左腿刺激 (train 0) trained={trained}")
|
||||||
|
client.send_data('train', 0)
|
||||||
|
time.sleep(train_time + epoch_wait) # 等待刺激时间+epoch完成时间
|
||||||
|
|
||||||
|
trained += 1
|
||||||
|
client.send_data('rest', 0)
|
||||||
|
time.sleep(max(0, abs(fault_rehabilitation - train_time) - epoch_wait))
|
||||||
|
|
||||||
|
# 右腿刺激
|
||||||
|
print(f"\n[Train] 右腿刺激 (train 1) trained={trained}")
|
||||||
|
client.send_data('train', 1)
|
||||||
|
time.sleep(train_time + epoch_wait) # 等待刺激时间+epoch完成时间
|
||||||
|
|
||||||
|
trained += 1
|
||||||
|
client.send_data('rest', 0)
|
||||||
|
time.sleep(max(0, fault_rehabilitation - epoch_wait))
|
||||||
|
|
||||||
|
# 个体校准阶段结束
|
||||||
|
print("\n[Phase] 个体校准结束,等待 paradigm=1 ...")
|
||||||
|
trained = 0
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
# -------- 康复训练阶段 --------
|
||||||
|
while server.paradigm == 1:
|
||||||
|
print("\n[Phase] 康复训练阶段 (paradigm=1)")
|
||||||
|
for block_idx in range(num_blocks):
|
||||||
|
print(f"\n [Block {block_idx+1}/{num_blocks}]")
|
||||||
|
time.sleep(10) # 每轮开始前等待
|
||||||
|
|
||||||
|
for trial_idx in range(num_trials):
|
||||||
|
true_position = truePos_seq[trial_idx]
|
||||||
|
print(f" [Trial {trial_idx+1}/{num_trials}] true_pos={true_position}")
|
||||||
|
|
||||||
|
time.sleep(0.5) # 提示 + 叮声
|
||||||
|
server.ChoosenNum = -1
|
||||||
|
|
||||||
|
# 开始测试
|
||||||
|
# predict epoch latency = 115包×20ms = 2.3s,需额外等待epoch完成
|
||||||
|
predict_epoch_wait = 2.3 # 秒,与predict latency=115包对应
|
||||||
|
client.send_data('predict', 1)
|
||||||
|
t_start = time.perf_counter()
|
||||||
|
while time.perf_counter() - t_start < test_time + predict_epoch_wait:
|
||||||
|
if server.ChoosenNum >= 0:
|
||||||
|
Num_Total += 1
|
||||||
|
user_choice.append(server.ChoosenNum)
|
||||||
|
if server.ChoosenNum in [0, 1]:
|
||||||
|
Num_Success += 1
|
||||||
|
rest_time = right_rehabilitation
|
||||||
|
break
|
||||||
|
time.sleep(0.02)
|
||||||
|
|
||||||
|
trained += 1
|
||||||
|
client.send_data('rest', 0)
|
||||||
|
time.sleep(0.5)
|
||||||
|
time.sleep(rest_time)
|
||||||
|
server.ChoosenNum = -1
|
||||||
|
|
||||||
|
# 训练结束
|
||||||
|
print("\n[Phase] 康复训练结束")
|
||||||
|
break # 退出康复训练循环
|
||||||
|
|
||||||
|
# 统计结果
|
||||||
|
overall_accuracy = Num_Success / Num_Total if Num_Total > 0 else 0
|
||||||
|
expected_seq = truePos_seq * num_blocks
|
||||||
|
min_len = min(len(user_choice), len(expected_seq))
|
||||||
|
same_count = sum(1 for a, b in zip(user_choice[:min_len], expected_seq[:min_len]) if a == b)
|
||||||
|
true_accuracy = same_count / min_len if min_len > 0 else 0
|
||||||
|
print(f"\n[Result] Overall={overall_accuracy:.3f} ({Num_Success}/{Num_Total})")
|
||||||
|
print(f"[Result] TrueAcc={true_accuracy:.3f} ({same_count}/{min_len})")
|
||||||
|
break # 完成一个完整流程后退出
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n[Headless] 用户中断")
|
||||||
|
finally:
|
||||||
|
client.send_data('predict', 2) # 关闭系统
|
||||||
|
client.send_data('saveData', 0)
|
||||||
|
server.stop()
|
||||||
|
print("[Headless] 已发送关闭指令,退出。")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
run_headless()
|
||||||
364
upperHost_stimmock/ssvep_main.py
Normal file
364
upperHost_stimmock/ssvep_main.py
Normal file
@@ -0,0 +1,364 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
|
from psychopy import visual, core, logging # import some libraries from PsychoPy
|
||||||
|
import random
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
# LAB STREAMING LAYER1
|
||||||
|
from pylsl import StreamInfo, StreamOutlet
|
||||||
|
from psychopy import event
|
||||||
|
import numpy as np
|
||||||
|
from DecoderDW.Server import TCPServer
|
||||||
|
from DecoderDW.Client import TCPClient
|
||||||
|
# import subprocess
|
||||||
|
|
||||||
|
# ----------------------
|
||||||
|
# constants
|
||||||
|
# size of the window
|
||||||
|
WINWIDTH = 1920
|
||||||
|
WINHEIGHT = 1080
|
||||||
|
REFRESH_RATE = 144
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def get_keypress():
|
||||||
|
keys = event.getKeys()
|
||||||
|
if keys:
|
||||||
|
return keys[0]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def shutdown(win,client):
|
||||||
|
client.send_data('saveData', 0)
|
||||||
|
client.send_data('predict',2)
|
||||||
|
win.close()
|
||||||
|
core.quit()
|
||||||
|
|
||||||
|
|
||||||
|
# end of configuration
|
||||||
|
# ----------------------
|
||||||
|
|
||||||
|
def generate_square_wave(frequency, sampling_rate=REFRESH_RATE, duration=5):
|
||||||
|
"""
|
||||||
|
生成方波序列
|
||||||
|
|
||||||
|
参数:
|
||||||
|
frequency (float): 频率(Hz)
|
||||||
|
sampling_rate (int): 采样率(Hz),应与屏幕刷新率一致
|
||||||
|
duration (float): 时长(秒)
|
||||||
|
|
||||||
|
返回:
|
||||||
|
square_wave (list): 方波序列
|
||||||
|
"""
|
||||||
|
# 计算总点数
|
||||||
|
n_points = int(duration * sampling_rate)
|
||||||
|
|
||||||
|
# 生成时间序列
|
||||||
|
time = np.linspace(0, duration, n_points, endpoint=False)
|
||||||
|
|
||||||
|
# 生成正弦波数据
|
||||||
|
sin_wave = np.sin(2 * np.pi * frequency * time)
|
||||||
|
# 生成方波数据
|
||||||
|
square_wave = np.where(sin_wave >= 0, 1, 0)
|
||||||
|
|
||||||
|
return square_wave.tolist()
|
||||||
|
|
||||||
|
|
||||||
|
# 启动一个进程,不等待其完成
|
||||||
|
import os
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# ----------------------------------------------------------------------------------
|
||||||
|
# main window settings
|
||||||
|
main_win = visual.Window(size=(WINWIDTH, WINHEIGHT), units='height', screen=0, fullscr=False,
|
||||||
|
gammaErrorPolicy='warn', color=(0.7, 0.7, 0.7))
|
||||||
|
print('starting 1')
|
||||||
|
# Set up LabStreamingLayer stream.
|
||||||
|
info = StreamInfo(name='psychopy_stimuli', type='Markers', channel_count=1, channel_format='string',
|
||||||
|
source_id='psychopy_stimuli_001')
|
||||||
|
outlet = StreamOutlet(info) # Broadcast the stream.
|
||||||
|
|
||||||
|
imageStim1 = visual.ImageStim(main_win, size=(300, 300), pos=(-600, 300), units='pix', image='UI/figures/xy.jpg')
|
||||||
|
txtStim1 = visual.TextStim(win=main_win, text='△', font='SimHei', height=80, color='black', units='pix', bold=True,
|
||||||
|
italic=False, pos=(-600, 30))
|
||||||
|
|
||||||
|
imageStim2 = visual.ImageStim(main_win, size=(300, 300), pos=(0, 300), units='pix', image='UI/figures/xy.jpg')
|
||||||
|
txtStim2 = visual.TextStim(win=main_win, text='△', font='SimHei', height=80, color='black', units='pix', bold=True,
|
||||||
|
italic=False, pos=(0, 30))
|
||||||
|
|
||||||
|
imageStim3 = visual.ImageStim(main_win, size=(300, 300), pos=(600, 300), units='pix', image='UI/figures/xy.jpg')
|
||||||
|
txtStim3 = visual.TextStim(win=main_win, text='△', font='SimHei', height=80, color='black', units='pix', bold=True,
|
||||||
|
italic=False, pos=(600, 30))
|
||||||
|
imageStim4 = visual.ImageStim(main_win, size=(300, 300), pos=(-600, -200), units='pix', image='UI/figures/xy.jpg')
|
||||||
|
txtStim4 = visual.TextStim(win=main_win, text='△', font='SimHei', height=80, color='black', units='pix', bold=True,
|
||||||
|
italic=False, pos=(-600, -470))
|
||||||
|
imageStim5 = visual.ImageStim(main_win, size=(300, 300), pos=(0, -200), units='pix', image='UI/figures/xy.jpg')
|
||||||
|
txtStim5 = visual.TextStim(win=main_win, text='△', font='SimHei', height=80, color='black', units='pix', bold=True,
|
||||||
|
italic=False, pos=(0, -470))
|
||||||
|
imageStim6 = visual.ImageStim(main_win, size=(300, 300), pos=(600, -200), units='pix', image='UI/figures/xy.jpg')
|
||||||
|
txtStim6 = visual.TextStim(win=main_win, text='△', font='SimHei', height=80, color='black', units='pix', bold=True,
|
||||||
|
italic=False, pos=(600, -470))
|
||||||
|
imageStim1red = visual.ImageStim(main_win, size=(300, 300), pos=(-600, 300), units='pix', image='UI/figures/xy_red.jpg')
|
||||||
|
imageStim2red = visual.ImageStim(main_win, size=(300, 300), pos=(0, 300), units='pix', image='UI/figures/xy_red.jpg')
|
||||||
|
imageStim3red = visual.ImageStim(main_win, size=(300, 300), pos=(600, 300), units='pix', image='UI/figures/xy_red.jpg')
|
||||||
|
imageStim4red = visual.ImageStim(main_win, size=(300, 300), pos=(-600, -200), units='pix', image='UI/figures/xy_red.jpg')
|
||||||
|
imageStim5red = visual.ImageStim(main_win, size=(300, 300), pos=(0, -200), units='pix', image='UI/figures/xy_red.jpg')
|
||||||
|
imageStim6red = visual.ImageStim(main_win, size=(300, 300), pos=(600, -200), units='pix', image='UI/figures/xy_red.jpg')
|
||||||
|
|
||||||
|
|
||||||
|
frequencies = [25,26,27,28,29,30] #[9,10,11,12,13,14] #[30,31,32,33,34,35] [25,26,27,28,29,30]
|
||||||
|
# 生成方波数据
|
||||||
|
square_wave_9 = generate_square_wave(frequencies[0], REFRESH_RATE, 5)
|
||||||
|
square_wave_11 = generate_square_wave(frequencies[1], REFRESH_RATE, 5)
|
||||||
|
square_wave_12 = generate_square_wave(frequencies[2], REFRESH_RATE, 5)
|
||||||
|
square_wave_13 = generate_square_wave(frequencies[3], REFRESH_RATE, 5)
|
||||||
|
square_wave_14 = generate_square_wave(frequencies[4], REFRESH_RATE, 5)
|
||||||
|
square_wave_15 = generate_square_wave(frequencies[5], REFRESH_RATE, 5)
|
||||||
|
|
||||||
|
# 创建刺激对象列表,便于管理
|
||||||
|
image_stims = [imageStim1, imageStim2, imageStim3, imageStim4, imageStim5, imageStim6]
|
||||||
|
txt_stims = [txtStim1, txtStim2, txtStim3, txtStim4, txtStim5, txtStim6]
|
||||||
|
square_waves = [square_wave_9, square_wave_11, square_wave_12, square_wave_13, square_wave_14, square_wave_15]
|
||||||
|
|
||||||
|
time.sleep(2)
|
||||||
|
# grating.color = 'black'
|
||||||
|
server = TCPServer()
|
||||||
|
server.start()
|
||||||
|
client = TCPClient('127.0.0.1', 8099)
|
||||||
|
client.connect()
|
||||||
|
print('Connected decoder_main')
|
||||||
|
# client.send_data('impedance', 1)
|
||||||
|
# time.sleep(20)
|
||||||
|
# client.send_data('impedance', 2)
|
||||||
|
client.send_data('targetFreqs', frequencies) # 使用frequencies变量,确保与刺激频率一致
|
||||||
|
time.sleep(1)
|
||||||
|
# 开启全程数据保存到 EEGFiles
|
||||||
|
client.send_data('saveData',1)
|
||||||
|
# client.send_data('impedance',1)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# 实验参数
|
||||||
|
repeats = 3
|
||||||
|
seq_freq = frequencies * repeats
|
||||||
|
seq_freq = np.random.permutation(seq_freq).tolist()
|
||||||
|
num_trials = len(seq_freq) # 总试验次数, 6*6=36
|
||||||
|
trial_count = 0
|
||||||
|
|
||||||
|
# 在线解码精度计算
|
||||||
|
online_results = [] # 存储每个trial的解码结果
|
||||||
|
correct_predictions = 0 # 正确预测计数
|
||||||
|
|
||||||
|
# 保存序列信息
|
||||||
|
seq_info = {
|
||||||
|
'total_trials': num_trials,
|
||||||
|
'frequencies': frequencies,
|
||||||
|
'sequence': seq_freq,
|
||||||
|
'start_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
}
|
||||||
|
# 保存序列信息到文件
|
||||||
|
import json
|
||||||
|
seq_file_path = f'EEGFiles/sequence_{datetime.now().strftime("%Y-%m-%d-%H-%M-%S")}.json'
|
||||||
|
with open(seq_file_path, 'a', encoding='utf-8') as f:
|
||||||
|
json.dump(seq_info, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
#========================Trials Started======================#
|
||||||
|
while trial_count < num_trials:
|
||||||
|
# 从序列中获取当前试验的目标频率
|
||||||
|
target_freq = seq_freq[trial_count]
|
||||||
|
target_freq_index = frequencies.index(target_freq)
|
||||||
|
print(f'Trials {trial_count + 1}/{num_trials} - Target Frequency: {target_freq}Hz (Label: {target_freq_index + 1})')
|
||||||
|
|
||||||
|
# Stage 1: Cue Stage
|
||||||
|
# print('Cue Stage: The target frequency is in Red')
|
||||||
|
client.send_data('setLabelAndTrialInfo', {
|
||||||
|
'label': 0,
|
||||||
|
'trial_info': {
|
||||||
|
'trial': trial_count + 1,
|
||||||
|
'phase': 'cue',
|
||||||
|
'target_freq': target_freq
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
for frameN in range(int(1 * REFRESH_RATE)): # 1秒提示
|
||||||
|
key_press = get_keypress()
|
||||||
|
if key_press in ['q']:
|
||||||
|
shutdown(main_win, client)
|
||||||
|
|
||||||
|
# 显示所有刺激,目标刺激为红色
|
||||||
|
for i, stim in enumerate(image_stims):
|
||||||
|
if i == target_freq_index:
|
||||||
|
# 目标刺激显示红色
|
||||||
|
if i == 0:
|
||||||
|
imageStim1red.draw()
|
||||||
|
elif i == 1:
|
||||||
|
imageStim2red.draw()
|
||||||
|
elif i == 2:
|
||||||
|
imageStim3red.draw()
|
||||||
|
elif i == 3:
|
||||||
|
imageStim4red.draw()
|
||||||
|
elif i == 4:
|
||||||
|
imageStim5red.draw()
|
||||||
|
elif i == 5:
|
||||||
|
imageStim6red.draw()
|
||||||
|
else:
|
||||||
|
# 其他刺激显示正常颜色
|
||||||
|
stim.draw()
|
||||||
|
|
||||||
|
main_win.flip()
|
||||||
|
|
||||||
|
# Stage 2: Flanker Stimulus
|
||||||
|
# print('Flanker Stage: flank all frequencies')
|
||||||
|
client.send_data('predict', 1)
|
||||||
|
client.send_data('setLabelAndTrialInfo', {
|
||||||
|
'label': target_freq_index + 1, # 设置目标频率标签 这里+1,是因为0代表不记录数据
|
||||||
|
'trial_info': {
|
||||||
|
'trial': trial_count + 1, # trial 从0开始
|
||||||
|
'phase': 'stimulus',
|
||||||
|
'target_freq': target_freq
|
||||||
|
}
|
||||||
|
})
|
||||||
|
outlet.push_sample(['S 1'])
|
||||||
|
|
||||||
|
for frameN in range(6 * REFRESH_RATE): # 6秒刺激
|
||||||
|
key_press = get_keypress()
|
||||||
|
if key_press in ['q']:
|
||||||
|
shutdown(main_win, client)
|
||||||
|
|
||||||
|
# 所有频率按照方波闪烁
|
||||||
|
if square_wave_9[frameN % len(square_wave_9)] == 1:
|
||||||
|
imageStim1.draw()
|
||||||
|
if square_wave_11[frameN % len(square_wave_11)] == 1:
|
||||||
|
imageStim2.draw()
|
||||||
|
if square_wave_12[frameN % len(square_wave_12)] == 1:
|
||||||
|
imageStim3.draw()
|
||||||
|
if square_wave_13[frameN % len(square_wave_13)] == 1:
|
||||||
|
imageStim4.draw()
|
||||||
|
if square_wave_14[frameN % len(square_wave_14)] == 1:
|
||||||
|
imageStim5.draw()
|
||||||
|
if square_wave_15[frameN % len(square_wave_15)] == 1:
|
||||||
|
imageStim6.draw()
|
||||||
|
|
||||||
|
main_win.flip()
|
||||||
|
if server.ChoosenNum != -1:
|
||||||
|
break
|
||||||
|
|
||||||
|
# 记录在线解码结果
|
||||||
|
predicted_freq_index = server.ChoosenNum # 解码结果
|
||||||
|
predicted_freq = frequencies[predicted_freq_index] if predicted_freq_index != -1 else -1
|
||||||
|
|
||||||
|
# 判断解码是否正确
|
||||||
|
is_correct = (predicted_freq_index == target_freq_index) if predicted_freq_index != -1 else False
|
||||||
|
if is_correct:
|
||||||
|
correct_predictions += 1
|
||||||
|
|
||||||
|
# 记录trial结果
|
||||||
|
trial_result = {
|
||||||
|
'trial': trial_count + 1,
|
||||||
|
'target_freq': target_freq,
|
||||||
|
'target_freq_index': target_freq_index,
|
||||||
|
'predicted_freq': predicted_freq,
|
||||||
|
'predicted_freq_index': predicted_freq_index,
|
||||||
|
'is_correct': is_correct,
|
||||||
|
'status': 'Success' if predicted_freq_index != -1 else 'Failed'
|
||||||
|
}
|
||||||
|
online_results.append(trial_result)
|
||||||
|
|
||||||
|
# 打印当前trial结果
|
||||||
|
status_symbol = "✓" if is_correct else "✗"
|
||||||
|
if predicted_freq_index == -1:
|
||||||
|
print(f'Trial {trial_count + 1}: 目标{target_freq}Hz -> 解码失败 - {status_symbol}')
|
||||||
|
else:
|
||||||
|
print(f'Trial {trial_count + 1}: 目标{target_freq}Hz -> 预测{predicted_freq}Hz - {status_symbol}')
|
||||||
|
|
||||||
|
|
||||||
|
# Stage 3: Decoding Feedback
|
||||||
|
outlet.push_sample(['S 2'])
|
||||||
|
client.send_data('setLabelAndTrialInfo', {
|
||||||
|
'label': 0, # 反馈阶段标签为0
|
||||||
|
'trial_info': {
|
||||||
|
'trial': trial_count + 1,
|
||||||
|
'phase': 'feedback',
|
||||||
|
'target_freq': target_freq
|
||||||
|
}
|
||||||
|
})
|
||||||
|
# print('反馈阶段: 显示解码结果')
|
||||||
|
|
||||||
|
for frameN in range(1 * REFRESH_RATE): # 1秒反馈
|
||||||
|
key_press = get_keypress()
|
||||||
|
if key_press in ['q']:
|
||||||
|
shutdown(main_win, client)
|
||||||
|
|
||||||
|
# 显示所有刺激但不闪烁
|
||||||
|
for stim in image_stims:
|
||||||
|
stim.draw()
|
||||||
|
|
||||||
|
# 显示解码结果
|
||||||
|
if server.ChoosenNum == 0:
|
||||||
|
txtStim1.draw()
|
||||||
|
elif server.ChoosenNum == 1:
|
||||||
|
txtStim2.draw()
|
||||||
|
elif server.ChoosenNum == 2:
|
||||||
|
txtStim3.draw()
|
||||||
|
elif server.ChoosenNum == 3:
|
||||||
|
txtStim4.draw()
|
||||||
|
elif server.ChoosenNum == 4:
|
||||||
|
txtStim5.draw()
|
||||||
|
elif server.ChoosenNum == 5:
|
||||||
|
txtStim6.draw()
|
||||||
|
|
||||||
|
main_win.flip()
|
||||||
|
|
||||||
|
server.ChoosenNum = -1
|
||||||
|
trial_count += 1
|
||||||
|
|
||||||
|
# 计算总体在线解码精度
|
||||||
|
total_trials = len(online_results)
|
||||||
|
successful_trials = len([r for r in online_results if r['status'] == 'Success'])
|
||||||
|
failed_trials = len([r for r in online_results if r['status'] == 'Failed'])
|
||||||
|
overall_accuracy = correct_predictions / total_trials if total_trials > 0 else 0
|
||||||
|
|
||||||
|
# Print Accuracy
|
||||||
|
print(f"Total Accuracy: {overall_accuracy:.3f} ({correct_predictions}/{total_trials})")
|
||||||
|
|
||||||
|
# 按频率分析准确率
|
||||||
|
print(f"\n=== 按频率分析准确率 ===")
|
||||||
|
freq_accuracy = {}
|
||||||
|
for result in online_results:
|
||||||
|
freq = result['target_freq']
|
||||||
|
if freq not in freq_accuracy:
|
||||||
|
freq_accuracy[freq] = {'correct': 0, 'total': 0, 'failed': 0}
|
||||||
|
|
||||||
|
freq_accuracy[freq]['total'] += 1
|
||||||
|
if result['status'] == 'Failed':
|
||||||
|
freq_accuracy[freq]['failed'] += 1
|
||||||
|
elif result['is_correct']:
|
||||||
|
freq_accuracy[freq]['correct'] += 1
|
||||||
|
|
||||||
|
print(f"{'频率':<8} {'准确率':<8} {'正确/总数':<10} {'失败数':<8}")
|
||||||
|
print("-" * 40)
|
||||||
|
for freq in sorted(freq_accuracy.keys()):
|
||||||
|
stats = freq_accuracy[freq]
|
||||||
|
accuracy = stats['correct'] / stats['total'] if stats['total'] > 0 else 0
|
||||||
|
print(f"{freq}Hz{'':<4} {accuracy:.3f}{'':<4} {stats['correct']}/{stats['total']}{'':<6} {stats['failed']}")
|
||||||
|
|
||||||
|
# 保存在线解码结果到文件
|
||||||
|
online_results_file = f'EEGFiles/online_results_{datetime.now().strftime("%Y-%m-%d-%H-%M-%S")}.json'
|
||||||
|
online_summary = {
|
||||||
|
'total_trials': total_trials,
|
||||||
|
'successful_trials': successful_trials,
|
||||||
|
'failed_trials': failed_trials,
|
||||||
|
'correct_predictions': correct_predictions,
|
||||||
|
'overall_accuracy': overall_accuracy,
|
||||||
|
# 'freq_accuracy': freq_accuracy,
|
||||||
|
'trial_results': online_results,
|
||||||
|
# 'end_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(online_results_file, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(online_summary, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
client.send_data('predict',2) # 关闭系统
|
||||||
|
main_win.close()
|
||||||
304
verify_datamock.py
Normal file
304
verify_datamock.py
Normal file
@@ -0,0 +1,304 @@
|
|||||||
|
"""
|
||||||
|
datamock 验证脚本(模拟算法端)
|
||||||
|
作为 ZMQ ROUTER 监听 8100 端口,等待 datamock.py 连接并验证数据流
|
||||||
|
|
||||||
|
运行顺序:
|
||||||
|
第一步: python verify_datamock.py (先启动,监听 8100)
|
||||||
|
第二步: python datamock.py (后启动,连接 8100)
|
||||||
|
"""
|
||||||
|
import zmq
|
||||||
|
import numpy as np
|
||||||
|
import time
|
||||||
|
import sys
|
||||||
|
import matplotlib
|
||||||
|
matplotlib.use('TkAgg')
|
||||||
|
|
||||||
|
# 在导入 pyplot 之前确保 Tkinter 正确初始化
|
||||||
|
try:
|
||||||
|
import tkinter as tk
|
||||||
|
root = tk.Tk()
|
||||||
|
root.withdraw() # 隐藏主窗口,我们只需要它的事件循环
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[WARN] Tkinter 初始化警告: {e}")
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
# ===== 可视化参数 =====
|
||||||
|
PLOT_WINDOW_SEC = 2.0 # 滑动窗口时长(秒)
|
||||||
|
PLOT_CHANNELS = [0, 1, 2, 3] # 要显示的 EEG 通道索引
|
||||||
|
|
||||||
|
SERVER_ADDR = 'tcp://127.0.0.1:8100'
|
||||||
|
FS = 250
|
||||||
|
N_SAMPLES_PER_PKT = 5
|
||||||
|
N_CHAN = 66
|
||||||
|
EEG_FREQ = 10
|
||||||
|
EEG_AMP = 100.0 # EEG 幅值 100μV(峰值)
|
||||||
|
EEG_AMP_MEAN = EEG_AMP * 2 / np.pi # 正弦波 |mean| ≈ 63.7μV
|
||||||
|
EEG_AMP_TOLERANCE = 1.5 # 幅值容差倍数
|
||||||
|
LABEL_INTERVAL = 5
|
||||||
|
FFT_SAMPLES = 250 # 做一次 FFT 需要的采样点数(1s数据)
|
||||||
|
EXPECTED_BYTES = N_SAMPLES_PER_PKT * N_CHAN * 4 # 1320 bytes (5*66*4)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_fft(samples):
|
||||||
|
"""对 Ch0 数据做 FFT,返回峰值频率"""
|
||||||
|
freqs = np.fft.rfftfreq(FFT_SAMPLES, d=1 / FS)
|
||||||
|
fft_mag = np.abs(np.fft.rfft(samples))
|
||||||
|
peak_idx = np.argmax(fft_mag[1:]) + 1 # 跳过 DC
|
||||||
|
return freqs[peak_idx], fft_mag, freqs
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
ctx = zmq.Context()
|
||||||
|
sock = ctx.socket(zmq.ROUTER)
|
||||||
|
sock.bind(SERVER_ADDR)
|
||||||
|
print(f"[{datetime.now().strftime('%H:%M:%S')}] ZMQ ROUTER 绑定 {SERVER_ADDR},等待 datamock.py 连接...\n")
|
||||||
|
|
||||||
|
# ===== 初始化交互式绘图 =====
|
||||||
|
plt.ion() # 开启交互模式
|
||||||
|
fig = plt.figure(figsize=(14, 10))
|
||||||
|
fig.suptitle('EEG Data Monitor (Real-time)', fontsize=14)
|
||||||
|
|
||||||
|
# 使用 GridSpec 进行布局
|
||||||
|
from matplotlib.gridspec import GridSpec
|
||||||
|
gs = GridSpec(len(PLOT_CHANNELS) + 2, 1, figure=fig, hspace=0.3)
|
||||||
|
axes = []
|
||||||
|
lines_eeg = []
|
||||||
|
for i, ch in enumerate(PLOT_CHANNELS):
|
||||||
|
ax = fig.add_subplot(gs[i])
|
||||||
|
axes.append(ax)
|
||||||
|
ax.set_ylabel(f'Ch{ch} (μV)', fontsize=8)
|
||||||
|
ax.grid(True, alpha=0.3)
|
||||||
|
ax.set_ylim(-150, 150)
|
||||||
|
line, = ax.plot([], [], lw=0.8)
|
||||||
|
lines_eeg.append(line)
|
||||||
|
ax.set_title(f'EEG Channel {ch}', fontsize=9)
|
||||||
|
|
||||||
|
# 标签通道子图 (Ch64 - 标签值)
|
||||||
|
ax_label = fig.add_subplot(gs[len(PLOT_CHANNELS)])
|
||||||
|
axes.append(ax_label)
|
||||||
|
ax_label.set_ylabel('Label Value', fontsize=8)
|
||||||
|
ax_label.grid(True, alpha=0.3)
|
||||||
|
ax_label.set_ylim(-0.5, 2.5)
|
||||||
|
line_label, = ax_label.plot([], [], 'ro-', lw=1.5, markersize=4)
|
||||||
|
line_label_data = line_label
|
||||||
|
ax_label.set_title('Ch64 - Label Value', fontsize=9)
|
||||||
|
|
||||||
|
# Ch65 标签序号子图
|
||||||
|
ax_seq = fig.add_subplot(gs[len(PLOT_CHANNELS) + 1])
|
||||||
|
axes.append(ax_seq)
|
||||||
|
ax_seq.set_ylabel('Label Seq', fontsize=8)
|
||||||
|
ax_seq.set_xlabel('Time (samples)', fontsize=8)
|
||||||
|
ax_seq.grid(True, alpha=0.3)
|
||||||
|
ax_seq.set_ylim(-0.5, 10)
|
||||||
|
line_seq, = ax_seq.plot([], [], 'gs-', lw=1.5, markersize=4)
|
||||||
|
line_seq_data = line_seq
|
||||||
|
ax_seq.set_title('Ch65 - Label Sequence', fontsize=9)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
# ===== 状态 =====
|
||||||
|
global_idx = 0 # 全局采样点索引
|
||||||
|
label_events = [] # 捕获的标签事件
|
||||||
|
start_time = None
|
||||||
|
fft_done = False
|
||||||
|
fft_buffer = [] # 暂存前 250 点做 FFT
|
||||||
|
ch64_zero_ok = True # 验证 Ch64 非标签采样点均为 0
|
||||||
|
ch65_zero_ok = True # 验证 Ch65 非标签采样点均为 0
|
||||||
|
label_pos_ok_all = True # 验证标签均在包内索引 4
|
||||||
|
|
||||||
|
# ===== 数据缓冲区 =====
|
||||||
|
max_samples = int(FS * PLOT_WINDOW_SEC)
|
||||||
|
eeg_buffer = {ch: np.zeros(max_samples) for ch in PLOT_CHANNELS}
|
||||||
|
label_buffer = np.zeros(max_samples)
|
||||||
|
seq_buffer = np.zeros(max_samples)
|
||||||
|
time_axis = np.arange(max_samples)
|
||||||
|
|
||||||
|
# ZMQ 收发统计
|
||||||
|
recv_count = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 首次 pause 用于显示窗口
|
||||||
|
plt.pause(0.5)
|
||||||
|
print(f"[INFO] 交互窗口已显示,如未看到请检查任务栏")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
# ROUTER recv: prepended 一个 identity 帧
|
||||||
|
# datamock 发送 3帧 [b'datamock', b'', data_bytes]
|
||||||
|
# ROUTER 接收后变成 4帧 [router_identity, b'datamock', b'', data_bytes]
|
||||||
|
frames = sock.recv_multipart()
|
||||||
|
recv_count += 1
|
||||||
|
now = time.time()
|
||||||
|
if start_time is None:
|
||||||
|
start_time = now
|
||||||
|
|
||||||
|
# 帧格式: [router_identity, b'datamock', b'', data_bytes]
|
||||||
|
router_id = frames[0] # ROUTER 添加的身份帧
|
||||||
|
identity = frames[1] # 发送端的 identity
|
||||||
|
_empty = frames[2] # 空帧
|
||||||
|
raw_data = frames[3] # 实际数据字节
|
||||||
|
|
||||||
|
# 数据长度校验
|
||||||
|
if len(raw_data) != EXPECTED_BYTES:
|
||||||
|
print(f"[ERROR] 数据长度错误: 期望{EXPECTED_BYTES}字节, 实际{len(raw_data)}字节")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 解析为 [5, 66] float32 数组
|
||||||
|
packet = np.frombuffer(raw_data, dtype=np.float32).reshape(N_SAMPLES_PER_PKT, N_CHAN)
|
||||||
|
|
||||||
|
elapsed = now - start_time
|
||||||
|
|
||||||
|
# ===== 验证 1: 数据形状 =====
|
||||||
|
if recv_count == 1:
|
||||||
|
shape_ok = packet.shape == (N_SAMPLES_PER_PKT, N_CHAN)
|
||||||
|
print(f"[{'✓' if shape_ok else '✗'}] 数据形状: {packet.shape} "
|
||||||
|
f"(期望 [{N_SAMPLES_PER_PKT}, {N_CHAN}])")
|
||||||
|
if not shape_ok:
|
||||||
|
print(f" ✗ 形状不匹配,退出")
|
||||||
|
break
|
||||||
|
|
||||||
|
# ===== 验证 2: EEG 幅值(首包) =====
|
||||||
|
if recv_count == 1:
|
||||||
|
eeg = packet[:, :64]
|
||||||
|
amp_mean = np.mean(np.abs(eeg))
|
||||||
|
amp_ok = amp_mean <= EEG_AMP_MEAN * EEG_AMP_TOLERANCE
|
||||||
|
print(f"[{'✓' if amp_ok else '✗'}] EEG 幅值: 均值={amp_mean:.2f}μV "
|
||||||
|
f"(期望 ~{EEG_AMP_MEAN:.2f}μV,峰值 ~{EEG_AMP:.2f}μV)")
|
||||||
|
if not amp_ok:
|
||||||
|
print(f" ✗ 幅值超出容差范围")
|
||||||
|
|
||||||
|
# ===== 验证 3: EEG 频率(首秒数据收集满后做 FFT) =====
|
||||||
|
fft_buffer.append(packet[:, 0].copy()) # 收集 Ch0
|
||||||
|
|
||||||
|
if not fft_done and len(fft_buffer) * N_SAMPLES_PER_PKT >= FFT_SAMPLES:
|
||||||
|
# 凑够 250 点,做 FFT
|
||||||
|
all_ch0 = np.concatenate(fft_buffer)[:FFT_SAMPLES]
|
||||||
|
peak_freq, fft_mag, freqs = validate_fft(all_ch0)
|
||||||
|
freq_ok = abs(peak_freq - EEG_FREQ) < 1.0
|
||||||
|
|
||||||
|
print(f"[{'✓' if freq_ok else '✗'}] EEG 频率: 峰值={peak_freq:.1f}Hz "
|
||||||
|
f"(期望 ~{EEG_FREQ}Hz)")
|
||||||
|
print(f" FFT 幅度谱前 5 峰值:")
|
||||||
|
top5 = np.argsort(fft_mag[1:])[-5:][::-1] + 1
|
||||||
|
for rank, idx in enumerate(top5):
|
||||||
|
print(f" {rank+1}. {freqs[idx]:.1f}Hz 幅度={fft_mag[idx]:.1f}")
|
||||||
|
print()
|
||||||
|
fft_done = True
|
||||||
|
|
||||||
|
# ===== 验证 4: 标签通道(Ch64/Ch65) =====
|
||||||
|
ch64 = packet[:, 64]
|
||||||
|
ch65 = packet[:, 65]
|
||||||
|
ch64_nonzero = np.where(ch64 != 0)[0]
|
||||||
|
ch65_nonzero = np.where(ch65 != 0)[0]
|
||||||
|
|
||||||
|
# 检查非标签采样点是否全为 0
|
||||||
|
ch64_zeros = np.all(ch64[:4] == 0)
|
||||||
|
ch65_zeros = np.all(ch65[:4] == 0)
|
||||||
|
ch64_zero_ok = ch64_zero_ok and ch64_zeros
|
||||||
|
ch65_zero_ok = ch65_zero_ok and ch65_zeros
|
||||||
|
|
||||||
|
if len(ch64_nonzero) > 0:
|
||||||
|
pos_in_pkt = int(ch64_nonzero[0])
|
||||||
|
label_val = int(ch64[pos_in_pkt])
|
||||||
|
label_seq = int(ch65[pos_in_pkt])
|
||||||
|
|
||||||
|
pos_ok = (len(ch64_nonzero) == 1 and pos_in_pkt == 4)
|
||||||
|
label_pos_ok_all = label_pos_ok_all and pos_ok
|
||||||
|
|
||||||
|
elapsed_since_start = now - start_time
|
||||||
|
print(f"[✓] 标签触发 @ {elapsed_since_start:.1f}s "
|
||||||
|
f"(global_idx={global_idx} 包{recv_count})")
|
||||||
|
print(f" Ch64 标签值: {label_val} Ch65 序号: {label_seq}")
|
||||||
|
print(f" 包内位置: 采样点 {pos_in_pkt}/4 "
|
||||||
|
f"({'✓' if pos_ok else '✗ 期望 4'}) "
|
||||||
|
f"其余采样点 Ch64=0: {'✓' if ch64_zeros else '✗'} "
|
||||||
|
f"Ch65=0: {'✓' if ch65_zeros else '✗'}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
label_events.append({
|
||||||
|
'time': elapsed_since_start,
|
||||||
|
'label': label_val,
|
||||||
|
'seq': label_seq
|
||||||
|
})
|
||||||
|
|
||||||
|
global_idx += N_SAMPLES_PER_PKT
|
||||||
|
|
||||||
|
# ===== 更新绘图缓冲区 =====
|
||||||
|
for ch_idx, ch in enumerate(PLOT_CHANNELS):
|
||||||
|
eeg_buffer[ch] = np.roll(eeg_buffer[ch], -N_SAMPLES_PER_PKT)
|
||||||
|
eeg_buffer[ch][-N_SAMPLES_PER_PKT:] = packet[:, ch]
|
||||||
|
|
||||||
|
label_buffer = np.roll(label_buffer, -N_SAMPLES_PER_PKT)
|
||||||
|
label_buffer[-N_SAMPLES_PER_PKT:] = packet[:, 64]
|
||||||
|
|
||||||
|
seq_buffer = np.roll(seq_buffer, -N_SAMPLES_PER_PKT)
|
||||||
|
seq_buffer[-N_SAMPLES_PER_PKT:] = packet[:, 65]
|
||||||
|
|
||||||
|
# ===== 实时更新绘图 =====
|
||||||
|
for i, ch in enumerate(PLOT_CHANNELS):
|
||||||
|
lines_eeg[i].set_data(time_axis, eeg_buffer[ch]) # 数据已是 μV 单位
|
||||||
|
line_label_data.set_data(time_axis, label_buffer)
|
||||||
|
line_seq_data.set_data(time_axis, seq_buffer)
|
||||||
|
|
||||||
|
# 设置 x 轴范围
|
||||||
|
for ax in axes:
|
||||||
|
ax.set_xlim(0, max_samples)
|
||||||
|
|
||||||
|
# 刷新图形(交互模式)
|
||||||
|
fig.canvas.draw_idle()
|
||||||
|
plt.pause(0.001)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n" + "=" * 55)
|
||||||
|
print(" 验证结果汇总")
|
||||||
|
print("=" * 55)
|
||||||
|
print(f" 运行时长: {time.time() - start_time:.1f}s")
|
||||||
|
print(f" 收到包数: {recv_count}")
|
||||||
|
print(f" FFT 验证: {'✓ 已完成' if fft_done else '✗ 未完成(时长不足1s)'}")
|
||||||
|
print(f" 非标签采样点 Ch64=0: {'✓' if ch64_zero_ok else '✗'}")
|
||||||
|
print(f" 非标签采样点 Ch65=0: {'✓' if ch65_zero_ok else '✗'}")
|
||||||
|
print(f" 标签均在包内位置4: {'✓' if label_pos_ok_all else '✗'}")
|
||||||
|
|
||||||
|
if label_events:
|
||||||
|
print(f"\n 共捕获 {len(label_events)} 次标签事件:")
|
||||||
|
for i, ev in enumerate(label_events):
|
||||||
|
print(f" {i+1}. t={ev['time']:.1f}s label={ev['label']} 序号={ev['seq']}")
|
||||||
|
|
||||||
|
# 标签间隔
|
||||||
|
print(f"\n 标签间隔验证 (期望 ~{LABEL_INTERVAL}s):")
|
||||||
|
for i in range(1, len(label_events)):
|
||||||
|
dt = label_events[i]['time'] - label_events[i-1]['time']
|
||||||
|
ok = abs(dt - LABEL_INTERVAL) < 0.1
|
||||||
|
print(f" {i}->{i+1}: {dt:.2f}s {'✓' if ok else '✗'}")
|
||||||
|
|
||||||
|
# 标签交替
|
||||||
|
labels = [e['label'] for e in label_events]
|
||||||
|
alt_ok = all(labels[i] != labels[i+1] for i in range(len(labels) - 1))
|
||||||
|
print(f"\n 标签交替: {labels} {'✓ 交替正确' if alt_ok else '✗ 交替错误'}")
|
||||||
|
|
||||||
|
# 序号
|
||||||
|
label1_seqs = [e['seq'] for e in label_events if e['label'] == 1]
|
||||||
|
label2_seqs = [e['seq'] for e in label_events if e['label'] == 2]
|
||||||
|
s1_ok = label1_seqs == list(range(1, len(label1_seqs) + 1))
|
||||||
|
s2_ok = label2_seqs == list(range(1, len(label2_seqs) + 1))
|
||||||
|
print(f" label=1 序号: {label1_seqs} {'✓' if s1_ok else '✗'}")
|
||||||
|
print(f" label=2 序号: {label2_seqs} {'✓' if s2_ok else '✗'}")
|
||||||
|
else:
|
||||||
|
print(f"\n 未捕获标签事件(运行时长不足 {LABEL_INTERVAL}s)")
|
||||||
|
|
||||||
|
print("=" * 55)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
sock.close()
|
||||||
|
ctx.term()
|
||||||
|
plt.ioff()
|
||||||
|
plt.close('all')
|
||||||
|
try:
|
||||||
|
root.destroy()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user