Files
bci_algo/datamock.py
Ivey Song a9dbe7261b update
2026-06-09 19:30:27 +08:00

189 lines
7.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import zmq
import numpy as np
import time
import threading
from datetime import datetime
# ========== 参数配置 ==========
FS = 250 # 采样率 Hz
N_SAMPLES_PER_PKT = 5 # 每包采样点数
N_CHAN = 66 # 通道数: 64 EEG + 1 标签值 + 1 标签序号
EEG_FREQ = 10 # EEG 正弦波频率 Hz
EEG_AMP = 100.0 # EEG 幅值 100μV
LABEL_INTERVAL = 5 # 标签间隔秒数
SERVER_ADDR = 'tcp://127.0.0.1:8100'
LABEL_CMD_ADDR = 'tcp://127.0.0.1:8101' # 接收来自上位机范式的标签命令
# 发送间隔: 每包 5 采样点 / 250Hz = 20ms
PKT_INTERVAL = N_SAMPLES_PER_PKT / FS
def build_packet(global_sample_idx):
"""
生成一包 [5, 66] 的 float64 数据
:param global_sample_idx: 当前包第一个采样点在全局序列中的索引 (从 0 开始)
:return: np.ndarray shape [5, 66]
"""
# 当前包内 5 个采样点对应的时间(秒)
t = (global_sample_idx + np.arange(N_SAMPLES_PER_PKT)) / FS
# Ch0-63: EEG 10Hz 正弦波,幅值 100μV
# t shape [5,]sin 乘以标量后仍是 [5,],需要 reshape 为 [5,1] 再广播到 64 通道
eeg = (EEG_AMP * np.sin(2 * np.pi * EEG_FREQ * t)).reshape(N_SAMPLES_PER_PKT, 1) # [5, 1]
eeg = np.tile(eeg, (1, 64)) # [5, 64]
# Ch64: 标签值通道,初始化为 0
event = np.zeros((N_SAMPLES_PER_PKT, 1), dtype=np.float64)
# Ch65: 标签序号通道,初始化为 0
label_idx = np.zeros((N_SAMPLES_PER_PKT, 1), dtype=np.float64)
# 拼成 [5, 66]
packet = np.concatenate([eeg, event, label_idx], axis=1).astype(np.float64)
return packet
def should_send_label(global_sample_idx):
"""
判断当前包是否包含标签触发点(每 5s 的最后一个采样点)
采样点索引从 0 开始,每 5s = 1250 个采样点
最后一个采样点索引: 1249, 2499, 3749, ...
由于每包 5 个采样点,标签点落在包内的最后一个采样点位置
即当前包起始索引 global_sample_idx 必须使得:
global_sample_idx <= 标签点索引 < global_sample_idx + N_SAMPLES_PER_PKT
也就是 global_sample_idx <= 1249 < global_sample_idx + 5
即 global_sample_idx = 1245, 2495, 3745, ...
即 global_sample_idx = n * LABEL_INTERVAL * FS - N_SAMPLES_PER_PKT
"""
samples_per_interval = LABEL_INTERVAL * FS
# 检查当前包是否包含 interval 的最后一个采样点
# 标签点索引 = n * 1250 - 1当 global_sample_idx = n*1250-5 时,标签在包内索引 4
return (global_sample_idx + N_SAMPLES_PER_PKT - 1) % samples_per_interval == samples_per_interval - 1
def main():
ctx = zmq.Context()
sock = ctx.socket(zmq.DEALER)
sock.connect(SERVER_ADDR)
print(f"[{datetime.now().strftime('%H:%M:%S')}] ZMQ Dealer 连接到 {SERVER_ADDR}")
# ========== 上位机标签命令监听 ==========
# 使用线程安全的队列接收来自 ssmvep_main.py 的标签命令
# 标签值: 1 (train 0), 2 (train 1), 99 (predict)
pending_label = [None] # [label_value or None]
label_lock = threading.Lock()
label_cmd_sock = ctx.socket(zmq.PULL)
label_cmd_sock.bind(LABEL_CMD_ADDR)
print(f"[{datetime.now().strftime('%H:%M:%S')}] 标签命令监听绑定到 {LABEL_CMD_ADDR}")
stop_recv = threading.Event()
def label_cmd_thread():
"""监听来自上位机范式的标签命令,写入 pending_label"""
while not stop_recv.is_set():
try:
msg = label_cmd_sock.recv_string(zmq.NOBLOCK)
label_val = int(msg)
with label_lock:
pending_label[0] = label_val
ts = datetime.now().strftime('%H:%M:%S')
label_name = {1: 'train_0', 2: 'train_1', 99: 'predict'}.get(label_val, str(label_val))
print(f"[{ts}] 收到标签命令: {label_name} -> label={label_val}")
except zmq.Again:
time.sleep(0.005)
except Exception as e:
print(f"[label_cmd_thread] 错误: {e}")
time.sleep(0.01)
label_thread = threading.Thread(target=label_cmd_thread, daemon=True)
label_thread.start()
print(f"[{datetime.now().strftime('%H:%M:%S')}] 标签命令监听线程已启动")
# 后台消费线程:持续 recv 从 ROUTER 返回的数据,避免 server 发送队列积压
recv_count = [0]
def consumer_thread():
"""消费线程:阻塞 recv丢弃收到的数据仅用于清空 ROUTER 发送队列"""
while not stop_recv.is_set():
try:
frames = sock.recv_multipart(zmq.NOBLOCK)
recv_count[0] += 1
# 收到的格式: [identity, '', filtered_data_bytes]
if recv_count[0] % 500 == 0:
print(f"[{datetime.now().strftime('%H:%M:%S')}] 消费线程已丢弃 {recv_count[0]} 帧滤波数据")
except zmq.Again:
time.sleep(0.01)
except zmq.error.Again: # 兼容旧版
time.sleep(0.01)
consumer = threading.Thread(target=consumer_thread, daemon=True)
consumer.start()
print(f"[{datetime.now().strftime('%H:%M:%S')}] 消费线程已启动daemon")
global_sample_idx = 0 # 全局采样点计数器
label_type = 1 # 当前标签类型: 1 或 2
label1_count = 0 # label=1 的序号计数器
label2_count = 0 # label=2 的序号计数器
packet_count = 0 # 已发送包数
print(f"[{datetime.now().strftime('%H:%M:%S')}] 开始发送模拟数据 ...")
print(f" 采样率: {FS}Hz | 每包 {N_SAMPLES_PER_PKT} 采样点 | 发送间隔 {PKT_INTERVAL*1000:.0f}ms")
print(f" EEG: {EEG_FREQ}Hz 正弦波 | 幅值 {EEG_AMP}μV")
print(f" 标签: 来自上位机范式命令 (train_0=1, train_1=2, predict=99)")
print("-" * 50)
try:
while True:
t_start = time.perf_counter()
# 构建当前包
packet = build_packet(global_sample_idx)
# 检查是否有来自上位机范式的挂起标签命令
with label_lock:
ext_label = pending_label[0]
if ext_label is not None:
pending_label[0] = None
if ext_label is not None:
# 将标签写入当前包所有5个采样点的第65通道 (index 64)
# 覆盖全部采样点确保 event_inner_idx 无论落在哪个位置都能被正确检测
packet[:, 64] = float(ext_label)
ts = datetime.now().strftime('%H:%M:%S')
print(f"[{ts}] 打标签: label={ext_label} -> ch64[all 5 samples] (global_sample_idx={global_sample_idx})")
# 发送: multipart 2帧 ['', data]
# 使用标准格式ROUTER 会自动附加 ZMQ 分配的客户端身份
sock.send_multipart([
b'',
packet.tobytes()
])
# 每 50 包打印一次进度
if packet_count % 50 == 0:
ts = datetime.now().strftime('%H:%M:%S')
print(f"[{ts}] 已发送 {packet_count} 包 (global_sample_idx={global_sample_idx})")
global_sample_idx += N_SAMPLES_PER_PKT
packet_count += 1
# 精确控制发送节奏: 等待到 PKT_INTERVAL 秒
elapsed = time.perf_counter() - t_start
sleep_time = PKT_INTERVAL - elapsed
if sleep_time > 0:
time.sleep(sleep_time)
except KeyboardInterrupt:
print(f"\n[{datetime.now().strftime('%H:%M:%S')}] 停止发送,共发送 {packet_count}")
finally:
stop_recv.set()
consumer.join(timeout=2)
label_cmd_sock.close()
sock.close()
ctx.term()
if __name__ == '__main__':
main()