Files
bci_algo/upperHost_stimmock/ssmvep_headless.py

302 lines
12 KiB
Python
Raw Permalink Normal View History

2026-06-10 09:19:31 +08:00
"""
ssmvep_headless.py
无界面版 SSMVEP 范式通讯流程模拟脚本
复现 ssmvep_main.py 的完整指令序列train 0/1/2, rest, predict, saveData
但不依赖 psychopy 也不打开任何窗口/音频 time.sleep 替代帧循环等待
启动顺序:
1. runDecoder.py
2. datamock.py
3. ssmvep_headless.py
"""
import sys
import os
import json
import time
import threading
import zmq
import numpy as np
from datetime import datetime
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from PubLibrary.InifileHelper import IniRead
# Subject name and session id; both are embedded in the name of the
# position-sequence JSON file written by run_headless().
personname = 'demo'
session = '01'
DATAMOCK_LABEL_ADDR = 'tcp://127.0.0.1:8101' # address datamock.py receives label commands on
# ========== ZMQ result receiving server ==========
class ZmqResultServer(threading.Thread):
    """Daemon thread that receives status messages on a bound ROUTER socket.

    Each accepted message is a multipart message with at least three frames
    whose third frame is a UTF-8 JSON object with ``method`` and ``params``
    keys.  Recognised methods update the public attributes below:

    * ``energy``   -> ``self.energy``
    * ``paradigm`` -> ``self.paradigm``
    * ``result``   -> ``self.ChoosenNum`` (and increments ``self.trial_idx``)

    The socket is only ever touched by the thread that runs :meth:`run`
    (pyzmq sockets are not thread-safe); :meth:`stop` merely requests
    termination and waits for that thread to close the socket.
    """

    def __init__(self, port=8088):
        """Create the ROUTER socket and bind it to ``tcp://0.0.0.0:<port>``.

        Args:
            port: TCP port to listen on.

        Raises:
            zmq.ZMQError: if the port cannot be bound (socket and context are
                released before re-raising).
        """
        threading.Thread.__init__(self)
        self.port = port
        self.running = True
        self.energy = 0
        # 0 = individual calibration, 1 = rehabilitation training,
        # 2 = waiting for model training
        self.paradigm = 0
        self.ChoosenNum = -1
        self.trial_idx = 0
        self.daemon = True
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.ROUTER)
        try:
            self.socket.bind(f"tcp://0.0.0.0:{self.port}")
        except Exception:
            # Do not leak the socket/context when the port is unavailable.
            self.socket.close(linger=0)
            self.context.term()
            raise

    def _handle_message(self, payload):
        """Decode one JSON payload frame and apply it to the server state."""
        message = json.loads(payload.decode('utf-8'))
        method = message.get('method')
        params = message.get('params')
        if method == 'energy':
            self.energy = params
        elif method == 'paradigm':
            self.paradigm = params
            print(f"[Server] paradigm -> {params}")
        elif method == 'result':
            self.ChoosenNum = params
            self.trial_idx += 1
            print(f"[Server] result={self.ChoosenNum} (trial {self.trial_idx})")

    def run(self):
        """Poll the socket until :meth:`stop` is called, then close it."""
        print(f"[Server] UpperHost_Server listening on {self.port}")
        try:
            while self.running:
                try:
                    frames = self.socket.recv_multipart(zmq.NOBLOCK)
                    if len(frames) < 3:
                        continue
                    self._handle_message(frames[2])
                except zmq.Again:
                    # Nothing queued; back off briefly instead of spinning.
                    time.sleep(0.005)
                except Exception as e:
                    print(f"[Server] error: {e}")
                    # Avoid a busy loop if the error is persistent.
                    time.sleep(0.01)
        finally:
            # Close from the owning thread so no other thread ever touches
            # the socket while a recv may be in progress.
            self.socket.close(linger=0)

    def stop(self):
        """Request termination, wait for the worker, and release ZMQ resources."""
        self.running = False
        if self.is_alive() and threading.current_thread() is not self:
            # One loop iteration takes ~5-10 ms, so this returns quickly.
            self.join(timeout=1.0)
        if not self.is_alive():
            # Either the thread has exited (socket already closed; closing
            # again is a no-op) or it was never started.
            self.socket.close(linger=0)
            self.context.term()
# ========== ZMQ command sending client ==========
class ZmqCmdClient:
    """Sends JSON commands to the decoder over a DEALER socket.

    A second PUSH socket forwards label values to datamock.py whenever a
    ``train`` or ``predict`` command is sent.  An optional background thread
    (see :meth:`start_recv_thread`) reads replies arriving on the DEALER
    socket and mirrors them onto a result-server object.

    The DEALER socket is shared between the caller's thread (sending) and the
    background thread (receiving); because pyzmq sockets are not thread-safe,
    every access to it is serialised with ``self._sock_lock``.
    """

    def __init__(self, host, port):
        """Create the sockets and connect the label PUSH socket.

        Args:
            host: decoder host name or IP address.
            port: decoder port.
        """
        self.host = host
        self.port = port
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.DEALER)
        self._sock_lock = threading.Lock()  # guards self.socket
        self._result_server = None
        self._stop_recv = None
        self._recv_thread = None
        # PUSH socket used to send label commands to datamock.py
        self._label_sock = self.context.socket(zmq.PUSH)
        self._label_sock.connect(DATAMOCK_LABEL_ADDR)
        print(f"[Client] label PUSH connected to {DATAMOCK_LABEL_ADDR}")

    def connect(self):
        """Connect the DEALER socket to ``tcp://<host>:<port>``."""
        with self._sock_lock:
            self.socket.connect(f"tcp://{self.host}:{self.port}")
        print(f"[Client] connected to {self.host}:{self.port}")

    def _handle_reply(self, frames):
        """Decode the last frame of a reply and update the result server."""
        # Replies received on the DEALER look like [b'', json_bytes].
        message = json.loads(frames[-1].decode('utf-8'))
        method = message.get('method')
        params = message.get('params')
        ts = datetime.now().strftime('%H:%M:%S.%f')[:-3]
        print(f"[{ts}] [CmdClient] recv: {method}={params}")
        server = self._result_server
        if method == 'paradigm':
            server.paradigm = params
            print(f"[{ts}] [CmdClient] paradigm updated -> {params}")
        elif method == 'result':
            server.ChoosenNum = params
            server.trial_idx += 1
            print(f"[{ts}] [CmdClient] result={params} (trial {server.trial_idx})")
        elif method == 'energy':
            server.energy = params

    def start_recv_thread(self, result_server):
        """Start a daemon thread that applies decoder replies to *result_server*.

        Args:
            result_server: object whose ``paradigm``, ``ChoosenNum``,
                ``trial_idx`` and ``energy`` attributes are updated from the
                ``paradigm`` / ``result`` / ``energy`` messages received.
        """
        self._result_server = result_server
        stop_event = threading.Event()
        self._stop_recv = stop_event

        def _recv_loop():
            while not stop_event.is_set():
                try:
                    # Hold the lock only for the non-blocking receive so that
                    # send_data() is never delayed noticeably.
                    with self._sock_lock:
                        frames = self.socket.recv_multipart(zmq.NOBLOCK)
                    self._handle_reply(frames)
                except zmq.Again:
                    time.sleep(0.005)
                except Exception as e:
                    print(f"[CmdClient recv] error: {e}")
                    time.sleep(0.01)

        self._recv_thread = threading.Thread(target=_recv_loop, daemon=True)
        self._recv_thread.start()
        print(f"[Client] 后台接收线程已启动(监听 decoder 8099 回发消息)")

    def stop_recv_thread(self):
        """Signal the receive thread to exit and wait briefly for it."""
        if getattr(self, '_stop_recv', None) is not None:
            self._stop_recv.set()
        worker = getattr(self, '_recv_thread', None)
        if (worker is not None and worker.is_alive()
                and worker is not threading.current_thread()):
            worker.join(timeout=1.0)

    def close(self, linger=1000):
        """Stop the receive thread and release both sockets and the context.

        Args:
            linger: milliseconds to keep unsent messages queued on close so
                that the last commands still have a chance to be delivered.
        """
        self.stop_recv_thread()
        with self._sock_lock:
            self.socket.close(linger=linger)
        self._label_sock.close(linger=linger)
        self.context.term()

    def _send_label(self, label_value):
        """Send a label command to datamock.py.

        Returns:
            True if the label was queued, False if sending failed (for
            example when no datamock peer is connected).
        """
        try:
            self._label_sock.send_string(str(label_value), zmq.NOBLOCK)
            return True
        except Exception as e:
            print(f"[Client] label send error: {e}")
            return False

    def send_data(self, method, params):
        """Send ``{'method': method, 'params': params}`` to the decoder.

        ``train 0`` / ``train 1`` additionally push label 1 / 2 to datamock,
        and any ``predict`` command pushes label 99.  Errors are reported on
        stdout rather than raised.
        """
        msg = {'method': method, 'params': params}
        try:
            payload = json.dumps(msg).encode('utf-8')
            with self._sock_lock:
                self.socket.send_multipart([b'', payload])
            ts = datetime.now().strftime('%H:%M:%S.%f')[:-3]
            print(f"[{ts}] send_data: {method}={params}")
            # Forward a label to datamock depending on the train/predict command.
            if method == 'train':
                if params == 0:
                    if self._send_label(1):
                        print(f"[Label] train 0 -> datamock label=1")
                elif params == 1:
                    if self._send_label(2):
                        print(f"[Label] train 1 -> datamock label=2")
            elif method == 'predict':
                if self._send_label(99):
                    print(f"[Label] predict -> datamock label=99")
        except Exception as e:
            print(f"[Client] send error: {e}")
# ========== Main flow ==========
def _build_trial_sequence(position, num_trials):
    """Return a shuffled list of exactly *num_trials* target positions.

    Positions are repeated as evenly as possible.  When *num_trials* is not a
    multiple of ``len(position)`` the remaining slots are filled with randomly
    chosen distinct positions, so indexing the result with any trial index in
    ``range(num_trials)`` is always valid.
    """
    if num_trials <= 0 or not position:
        return []
    repeats, remainder = divmod(num_trials, len(position))
    sequence = list(position) * repeats
    if remainder:
        sequence += np.random.permutation(position)[:remainder].tolist()
    return np.random.permutation(sequence).tolist()


def _run_calibration(server, client, train_time, epoch_wait, fault_rehabilitation):
    """Alternate left/right training stimuli while the decoder stays in paradigm 0."""
    trained = 0
    while server.paradigm == 0:
        # Left-leg stimulus
        print(f"\n[Train] 左腿刺激 (train 0) trained={trained}")
        client.send_data('train', 0)
        time.sleep(train_time + epoch_wait)  # stimulus time + epoch completion time
        trained += 1
        client.send_data('rest', 0)
        # NOTE(review): this rest uses |fault_rehabilitation - train_time| while
        # the rest after the right-leg stimulus uses fault_rehabilitation alone;
        # kept as in the original - confirm which one is intended.
        time.sleep(max(0, abs(fault_rehabilitation - train_time) - epoch_wait))
        # Right-leg stimulus
        print(f"\n[Train] 右腿刺激 (train 1) trained={trained}")
        client.send_data('train', 1)
        time.sleep(train_time + epoch_wait)  # stimulus time + epoch completion time
        trained += 1
        client.send_data('rest', 0)
        time.sleep(max(0, fault_rehabilitation - epoch_wait))


def _run_rehabilitation(server, client, truePos_seq, num_blocks, num_trials,
                        test_time, predict_epoch_wait, rest_time,
                        right_rehabilitation):
    """Run all prediction blocks.

    Returns:
        ``(user_choice, expected, total, success)`` where ``user_choice[i]`` is
        the i-th received result and ``expected[i]`` the target position of
        the trial that produced it.  Trials that time out add to neither list,
        so the two lists always stay aligned.
    """
    user_choice = []
    expected = []
    total = 0
    success = 0
    print("\n[Phase] 康复训练阶段 (paradigm=1)")
    for block_idx in range(num_blocks):
        print(f"\n [Block {block_idx+1}/{num_blocks}]")
        time.sleep(10)  # wait before each block starts
        for trial_idx in range(num_trials):
            true_position = truePos_seq[trial_idx]
            print(f" [Trial {trial_idx+1}/{num_trials}] true_pos={true_position}")
            time.sleep(0.5)  # cue + ding sound
            server.ChoosenNum = -1
            # Rest duration is decided per trial so that one successful trial
            # does not change the rest of every later trial.
            trial_rest = rest_time
            client.send_data('predict', 1)
            deadline = time.perf_counter() + test_time + predict_epoch_wait
            while time.perf_counter() < deadline:
                # Snapshot once: a background thread may update the attribute.
                choice = server.ChoosenNum
                if choice >= 0:
                    total += 1
                    user_choice.append(choice)
                    expected.append(true_position)
                    if choice in (0, 1):
                        success += 1
                        trial_rest = right_rehabilitation
                    break
                time.sleep(0.02)
            client.send_data('rest', 0)
            time.sleep(0.5)
            time.sleep(trial_rest)
            server.ChoosenNum = -1
    print("\n[Phase] 康复训练结束")
    return user_choice, expected, total, success


def run_headless():
    """Drive one full calibration + rehabilitation session without any GUI.

    Starts the result server and command client, writes the randomised target
    sequence to ``EEGFiles/``, runs the calibration phase while the decoder
    reports paradigm 0, waits for model training (paradigm 2) to finish, runs
    the rehabilitation phase if the decoder reports paradigm 1, prints the
    accuracy statistics, and finally sends the shutdown/save commands.
    """
    server = ZmqResultServer(port=8088)
    server.start()
    _dh = str(IniRead('system', 'Decoder_Host'))
    _dp = int(IniRead('system', 'Decoder_Port'))
    client = ZmqCmdClient(_dh, _dp)
    client.connect()
    # Background thread that listens for paradigm/result messages sent back
    # by the decoder on the command connection.
    client.start_recv_thread(server)
    time.sleep(1)  # give the connections time to establish
    client.send_data('decoderClass', 'ssmvep')

    train_time = 2.5  # stimulus duration of each training round (s)
    test_time = 2.5   # stimulus duration of each test round (s)
    right_rehabilitation = float(IniRead('system', 'Right_rehabilitation'))
    fault_rehabilitation = float(IniRead('system', 'Fault_rehabilitation'))
    rest_time = float(IniRead('system', 'Rest_time'))
    num_blocks = int(IniRead('system', 'Num_blocks'))
    num_trials = int(IniRead('system', 'Num_trials'))

    position = [0, 1]
    truePos_seq = _build_trial_sequence(position, num_trials)

    os.makedirs('EEGFiles', exist_ok=True)
    seq_file_path = f'EEGFiles/pos_seq_{personname}{session}_{datetime.now().strftime("%Y-%m-%d-%H-%M-%S")}.json'
    seq_info = {
        'position': position,
        'sequence': truePos_seq,
        'start_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    }
    with open(seq_file_path, 'w', encoding='utf-8') as f:
        json.dump(seq_info, f, ensure_ascii=False, indent=2)

    user_choice = []
    expected = []
    Num_Total = 0
    Num_Success = 0
    print("=" * 50)
    print("[Headless] 开始运行 SSMVEP 通讯流程(无界面)")
    print(f" num_blocks={num_blocks}, num_trials={num_trials}")
    print(f" train_time={train_time}s, test_time={test_time}s")
    print("=" * 50)
    try:
        # -------- Individual calibration phase --------
        print("\n[Phase] 个体校准阶段 (paradigm=0)")
        client.send_data('rest', 0)
        time.sleep(1)
        # Extra wait needed for an epoch to complete
        # (train_latency = 120 packets x 20 ms = 2.4 s): after train_time the
        # decoder needs epoch_wait more seconds to finish collecting the epoch.
        epoch_wait = 2.4
        _run_calibration(server, client, train_time, epoch_wait, fault_rehabilitation)

        print("\n[Phase] 个体校准结束,等待 paradigm=1 ...")
        time.sleep(1)
        # paradigm 2 means the decoder is still training the model; actually
        # wait for it instead of skipping the rehabilitation phase.
        while server.paradigm == 2:
            time.sleep(0.1)

        # -------- Rehabilitation training phase --------
        if server.paradigm == 1:
            # predict epoch latency = 115 packets x 20 ms = 2.3 s
            predict_epoch_wait = 2.3
            user_choice, expected, Num_Total, Num_Success = _run_rehabilitation(
                server, client, truePos_seq, num_blocks, num_trials,
                test_time, predict_epoch_wait, rest_time, right_rehabilitation)

        # Statistics
        overall_accuracy = Num_Success / Num_Total if Num_Total > 0 else 0
        min_len = len(user_choice)
        same_count = sum(1 for a, b in zip(user_choice, expected) if a == b)
        true_accuracy = same_count / min_len if min_len > 0 else 0
        print(f"\n[Result] Overall={overall_accuracy:.3f} ({Num_Success}/{Num_Total})")
        print(f"[Result] TrueAcc={true_accuracy:.3f} ({same_count}/{min_len})")
    except KeyboardInterrupt:
        print("\n[Headless] 用户中断")
    finally:
        client.send_data('predict', 2)  # shut the system down
        client.send_data('saveData', 0)
        # Stop the reply thread before the server so nothing writes to the
        # server object while it is being torn down.
        client.stop_recv_thread()
        server.stop()
        print("[Headless] 已发送关闭指令,退出。")


if __name__ == '__main__':
    run_headless()