306 lines
12 KiB
Python
306 lines
12 KiB
Python
|
|
"""
|
|||
|
|
MI_headless.py
|
|||
|
|
无界面版 MI 运动想象范式通讯流程模拟脚本。
|
|||
|
|
复现 MI_main.py 的完整指令序列(train 0/1, rest, predict, saveData),
|
|||
|
|
但不依赖 psychopy 也不打开任何窗口/音频,用 time.sleep 替代帧循环等待。
|
|||
|
|
|
|||
|
|
启动顺序:
|
|||
|
|
1. runDecoder.py
|
|||
|
|
2. datamock.py
|
|||
|
|
3. MI_headless.py
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import sys
|
|||
|
|
import os
|
|||
|
|
import json
|
|||
|
|
import time
|
|||
|
|
import threading
|
|||
|
|
import zmq
|
|||
|
|
import numpy as np
|
|||
|
|
import ast
|
|||
|
|
from datetime import datetime
|
|||
|
|
|
|||
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|||
|
|
from PubLibrary.InifileHelper import IniRead
|
|||
|
|
|
|||
|
|
# Participant / session identifiers (not read by any code visible in this file).
personname: str = "demo"
session: str = "01"

# datamock's label-command endpoint; ZmqCmdClient connects a PUSH socket to it.
DATAMOCK_LABEL_ADDR: str = "tcp://127.0.0.1:8101"
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ========== ZMQ 结果接收服务 ==========
|
|||
|
|
class ZmqResultServer(threading.Thread):
    """Daemon thread that receives JSON status messages on a ZMQ ROUTER socket.

    The socket is bound to ``tcp://0.0.0.0:<port>`` in the constructor.  Each
    incoming ``[identity, b'', json_payload]`` message whose ``method`` is
    ``energy``, ``paradigm`` or ``result`` updates the matching attribute;
    other methods are ignored.  The main flow polls these attributes.

    Attributes:
        energy: ``params`` of the last ``energy`` message (initially 0).
        paradigm: 0 = individual calibration, 1 = rehabilitation training,
            2 = waiting for model training (initially 0).
        ChoosenNum: ``params`` of the last ``result`` message (-1 = none).
        trial_idx: number of ``result`` messages received so far.
    """

    def __init__(self, port=8088):
        threading.Thread.__init__(self)
        self.port = port
        self.running = True
        self.energy = 0
        self.paradigm = 0  # 0 = calibration, 1 = rehabilitation, 2 = waiting for model training
        self.ChoosenNum = -1
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.ROUTER)
        self.socket.bind(f"tcp://0.0.0.0:{self.port}")
        self.daemon = True
        self.trial_idx = 0

    def run(self):
        """Poll the ROUTER socket until ``stop()`` clears ``running``."""
        print(f"[Server] UpperHost_Server listening on {self.port}")
        while self.running:
            try:
                frames = self.socket.recv_multipart(zmq.NOBLOCK)
            except zmq.Again:
                time.sleep(0.005)
                continue
            except Exception as e:
                print(f"[Server] error: {e}")
                # Back off so a persistently failing socket does not busy-spin.
                time.sleep(0.01)
                continue
            # A DEALER peer sending [b'', payload] arrives as
            # [identity, b'', payload]; ignore anything shorter.
            if len(frames) < 3:
                continue
            try:
                self._handle_message(json.loads(frames[2].decode('utf-8')))
            except Exception as e:
                print(f"[Server] error: {e}")

    def _handle_message(self, message):
        """Apply one decoded JSON message (a dict with ``method``/``params``) to the state."""
        method = message.get('method')
        params = message.get('params')
        if method == 'energy':
            self.energy = params
        elif method == 'paradigm':
            self.paradigm = params
            print(f"[Server] paradigm -> {params}")
        elif method == 'result':
            self.ChoosenNum = params
            self.trial_idx += 1
            print(f"[Server] result={self.ChoosenNum} (trial {self.trial_idx})")

    def stop(self, timeout=1.0):
        """Stop the receive loop, then close the socket and terminate the context.

        Args:
            timeout: seconds to wait for the receive thread to exit.

        ZMQ sockets must not be used from two threads at once, so the socket
        is closed only after ``run()`` has returned.  If the thread does not
        exit in time, the socket and context are left open rather than closed
        underneath it.
        """
        self.running = False
        if self.is_alive() and threading.current_thread() is not self:
            self.join(timeout)
        if self.is_alive():
            print("[Server] warning: receive thread still running; socket left open")
            return
        # Nothing is ever sent from this ROUTER, so there is nothing to linger for.
        self.socket.close(linger=0)
        self.context.term()
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ========== ZMQ 命令发送客户端 ==========
|
|||
|
|
class ZmqCmdClient:
    """Send JSON commands to the decoder and trial labels to datamock.

    Commands go out on a DEALER socket that ``connect()`` attaches to
    ``tcp://<host>:<port>``.  Messages the decoder sends back on the same
    socket can be consumed by ``start_recv_thread``.  A separate PUSH socket,
    connected in the constructor, forwards labels to ``DATAMOCK_LABEL_ADDR``.
    """

    def __init__(self, host, port):
        self.host = host
        self.port = port
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.DEALER)
        # ZMQ sockets are not thread-safe.  The DEALER is shared between the
        # caller's thread (send_data) and the receive thread, so every access
        # goes through this lock.
        self._sock_lock = threading.Lock()
        # PUSH socket that sends label commands to datamock.py
        self._label_sock = self.context.socket(zmq.PUSH)
        self._label_sock.connect(DATAMOCK_LABEL_ADDR)
        print(f"[Client] label PUSH connected to {DATAMOCK_LABEL_ADDR}")

    def connect(self):
        """Connect the DEALER command socket to the decoder."""
        with self._sock_lock:
            self.socket.connect(f"tcp://{self.host}:{self.port}")
        print(f"[Client] connected to {self.host}:{self.port}")

    def start_recv_thread(self, result_server):
        """Start a daemon thread that applies decoder replies to ``result_server``.

        Per the original note, the replies come from the decoder's 8099 ROUTER.
        ``paradigm``, ``result`` and ``energy`` messages update the matching
        attributes of ``result_server`` (see ``_handle_message``).
        """
        self._result_server = result_server
        self._stop_recv = threading.Event()

        def _recv_loop():
            while not self._stop_recv.is_set():
                try:
                    # Hold the lock only for the non-blocking receive, never while sleeping.
                    with self._sock_lock:
                        frames = self.socket.recv_multipart(zmq.NOBLOCK)
                    # The DEALER receives [b'', json_bytes]; the payload is the last frame.
                    self._handle_message(json.loads(frames[-1].decode('utf-8')))
                except zmq.Again:
                    time.sleep(0.005)
                except Exception as e:
                    print(f"[CmdClient recv] error: {e}")
                    time.sleep(0.01)

        self._recv_thread = threading.Thread(target=_recv_loop, daemon=True)
        self._recv_thread.start()
        print(f"[Client] 后台接收线程已启动(监听 decoder 8099 回发消息)")

    def _handle_message(self, message):
        """Log one decoded reply and copy it onto the result server's attributes."""
        method = message.get('method')
        params = message.get('params')
        ts = datetime.now().strftime('%H:%M:%S.%f')[:-3]
        print(f"[{ts}] [CmdClient] recv: {method}={params}")
        server = self._result_server
        if method == 'paradigm':
            server.paradigm = params
            print(f"[{ts}] [CmdClient] paradigm updated -> {params}")
        elif method == 'result':
            server.ChoosenNum = params
            server.trial_idx += 1
            print(f"[{ts}] [CmdClient] result={params} (trial {server.trial_idx})")
        elif method == 'energy':
            server.energy = params

    def stop_recv_thread(self, timeout=1.0):
        """Signal the receive thread to stop and wait up to ``timeout`` seconds for it.

        This is a no-op if ``start_recv_thread`` was never called.
        """
        if hasattr(self, '_stop_recv'):
            self._stop_recv.set()
            thread = getattr(self, '_recv_thread', None)
            if thread is not None and thread is not threading.current_thread():
                thread.join(timeout)

    def close(self, linger=1000):
        """Stop the receive thread, then close both sockets and terminate the context.

        Args:
            linger: milliseconds that still-queued outgoing messages (for
                example final shutdown commands) may take to be delivered.
        """
        self.stop_recv_thread()
        with self._sock_lock:
            self.socket.close(linger=linger)
        self._label_sock.close(linger=linger)
        self.context.term()

    def _send_label(self, label_value):
        """Send a label command to datamock.py without blocking; errors are only logged."""
        try:
            self._label_sock.send_string(str(label_value), zmq.NOBLOCK)
        except Exception as e:
            print(f"[Client] label send error: {e}")

    def send_data(self, method, params):
        """Send ``{'method': method, 'params': params}`` to the decoder.

        ``train 0`` / ``train 1`` / ``predict`` additionally send the labels
        1 / 2 / 99 to datamock.  Send errors are logged, not raised.
        """
        msg = {'method': method, 'params': params}
        try:
            payload = json.dumps(msg).encode('utf-8')
            with self._sock_lock:
                self.socket.send_multipart([b'', payload])
            ts = datetime.now().strftime('%H:%M:%S.%f')[:-3]
            print(f"[{ts}] send_data: {method}={params}")
            # Forward the matching label to datamock for train/predict commands.
            if method == 'train':
                if params == 0:
                    self._send_label(1)
                    print(f"[Label] train 0 -> datamock label=1")
                elif params == 1:
                    self._send_label(2)
                    print(f"[Label] train 1 -> datamock label=2")
            elif method == 'predict':
                # NOTE(review): this also fires for 'predict' 2, which
                # run_headless sends as its shutdown command. Confirm datamock
                # expects a 99 label in that case.
                self._send_label(99)
                print(f"[Label] predict -> datamock label=99")
        except Exception as e:
            print(f"[Client] send error: {e}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ========== 主流程 ==========
|
|||
|
|
def run_headless():
    """Run the full MI command sequence against the decoder without any UI.

    Steps:
      1. Start the result server on port 8088 and a command client connected
         to ``[system] Decoder_Host:Decoder_Port`` from the ini file.
      2. Calibration: alternate ``train 0`` / ``train 1`` while the decoder
         reports ``paradigm == 0``.
      3. Wait while ``paradigm == 2`` (model training).
      4. If ``paradigm == 1``: run ``Num_blocks`` x ``Num_trials`` ``predict``
         windows and tally the results (result 0 counts as a success).
      5. Always send ``predict 2`` (shutdown) and ``saveData`` on exit.

    Ctrl-C is reported and still goes through the shutdown sequence.
    """
    server = ZmqResultServer(port=8088)
    server.start()

    decoder_host = str(IniRead('system', 'Decoder_Host'))
    decoder_port = int(IniRead('system', 'Decoder_Port'))
    client = ZmqCmdClient(decoder_host, decoder_port)
    client.connect()
    # Background thread applying paradigm/result messages sent back by the decoder (8099).
    client.start_recv_thread(server)

    time.sleep(1)  # let the connections establish
    client.send_data('decoderClass', 'mi')

    # MI_IntervalEpoch, e.g. [0.5, 4.5] -> trial length 4.5 - 0.5 = 4.0 s
    mi_interval = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
    trial_sec = float(mi_interval[1] - mi_interval[0])
    margin = 1.0
    train_time = max(5.0, trial_sec + margin)  # stimulus duration (kept consistent with MI_main.py)

    # After the stimulus, the decoder needs `latency` more packets to finish
    # collecting the epoch: latency = int(interval_end * 250) // 5 packets of
    # 20 ms each (e.g. 4.5 -> 225 packets -> 4.5 s).  For MI,
    # train_latency == latency, so train and predict share the same wait.
    epoch_wait = (int(mi_interval[1] * 250) // 5) * 0.02
    predict_epoch_wait = epoch_wait

    test_time = 7.0  # prediction window length (kept consistent with MI_main.py)
    right_rehabilitation = float(IniRead('system', 'Right_rehabilitation'))
    fault_rehabilitation = float(IniRead('system', 'Fault_rehabilitation'))
    rest_time = float(IniRead('system', 'Rest_time'))  # default rest when no result arrives

    num_blocks = int(IniRead('system', 'Num_blocks'))
    num_trials = int(IniRead('system', 'Num_trials'))

    trained = 0
    num_total = 0
    num_success = 0
    user_choice = []

    print("=" * 50)
    print("[Headless] 开始运行 MI 通讯流程(无界面)")
    print(f" MI_IntervalEpoch={mi_interval}, trial_sec={trial_sec:.2f}s")
    print(f" train_time={train_time:.2f}s, epoch_wait={epoch_wait:.2f}s")
    print(f" test_time={test_time:.2f}s, predict_epoch_wait={predict_epoch_wait:.2f}s")
    print(f" num_blocks={num_blocks}, num_trials={num_trials}")
    print("=" * 50)

    try:
        # -------- individual calibration phase --------
        print("\n[Phase] 个体校准阶段 (paradigm=0)")
        client.send_data('rest', 0)
        time.sleep(1)

        while server.paradigm == 0:
            # Left-hand MI stimulus (train 0 -> datamock label 1)
            print(f"\n[Train] 左侧 MI 刺激 (train 0) trained={trained}")
            client.send_data('rest', 0)
            time.sleep(0.5)  # pause after the cue tone
            client.send_data('train', 0)
            time.sleep(train_time + epoch_wait)  # stimulus + epoch completion
            trained += 1
            client.send_data('rest', 0)
            time.sleep(1.0)  # rest between classes

            # Idle-state sample (train 1 -> datamock label 2)
            print(f"\n[Train] 空闲态采集 (train 1) trained={trained}")
            client.send_data('train', 1)
            time.sleep(train_time + epoch_wait)  # stimulus + epoch completion
            trained += 1
            client.send_data('rest', 0)
            time.sleep(1.0)  # rest between classes

        print("\n[Phase] 个体校准结束,等待模型训练 (paradigm=2) ...")
        time.sleep(1)

        # Wait for model training to finish (paradigm 2 -> 1)
        while server.paradigm == 2:
            print("[Phase] 等待模型训练完成 ...")
            time.sleep(0.5)

        # -------- rehabilitation training phase --------
        if server.paradigm == 1:
            print("\n[Phase] 康复训练阶段 (paradigm=1)")
            for block_idx in range(num_blocks):
                print(f"\n [Block {block_idx+1}/{num_blocks}]")
                time.sleep(10)  # pause before each block

                for trial_idx in range(num_trials):
                    print(f" [Trial {trial_idx+1}/{num_trials}]")
                    time.sleep(0.5)  # cue tone
                    server.ChoosenNum = -1
                    # Each trial starts from the configured rest; the old code
                    # carried the previous trial's rest into timed-out trials.
                    trial_rest = rest_time

                    # Predict; the window includes the extra epoch-completion wait.
                    client.send_data('predict', 1)
                    t_start = time.perf_counter()
                    while time.perf_counter() - t_start < test_time + predict_epoch_wait:
                        # Read once: the receive threads may overwrite ChoosenNum at any time.
                        choice = server.ChoosenNum
                        if choice >= 0:
                            num_total += 1
                            user_choice.append(choice)
                            if choice == 0:
                                num_success += 1
                                trial_rest = right_rehabilitation
                            elif choice == 1:
                                trial_rest = fault_rehabilitation
                            break
                        time.sleep(0.02)

                    client.send_data('rest', 0)
                    time.sleep(0.5)
                    time.sleep(trial_rest)
                    server.ChoosenNum = -1

            print("\n[Phase] 康复训练结束")

        # Summary
        overall_accuracy = num_success / num_total if num_total > 0 else 0
        print(f"\n[Result] Overall={overall_accuracy:.3f} ({num_success}/{num_total})")
        print(f"[Result] user_choice={user_choice}")

    except KeyboardInterrupt:
        print("\n[Headless] 用户中断")

    finally:
        client.send_data('predict', 2)  # shut the system down
        client.send_data('saveData', 0)
        client.stop_recv_thread()
        # ZMQ sends are asynchronous and the client's sockets are not closed
        # with a linger here.  Give the I/O thread a moment to deliver the
        # shutdown commands before the process exits.
        time.sleep(0.5)
        server.stop()
        print("[Headless] 已发送关闭指令,退出。")
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    # Script entry point: run the headless MI communication flow.
    run_headless()
|