From 31d91d6cc7c6f20017aeebbc00943a15c0075cae Mon Sep 17 00:00:00 2001 From: lizhao Date: Mon, 8 Jun 2026 15:47:25 +0800 Subject: [PATCH] update float32 to float64 --- MI/Algorithm/conformer_2class.py | 2 +- MI/Algorithm/conformer_2class_cpu.py | 2 +- Zmq/dataBuffer.py | 2 +- Zmq/filterProcess.py | 2 +- Zmq/zmqServer.py | 4 ++-- datamock.py | 8 ++++---- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/MI/Algorithm/conformer_2class.py b/MI/Algorithm/conformer_2class.py index 8148b68..f2a02ae 100644 --- a/MI/Algorithm/conformer_2class.py +++ b/MI/Algorithm/conformer_2class.py @@ -82,7 +82,7 @@ class MultiHeadAttention(nn.Module): values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads) energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) if mask is not None: - fill_value = torch.finfo(torch.float32).min + fill_value = torch.finfo(torch.float64).min energy.mask_fill(~mask, fill_value) scaling = self.emb_size ** (1 / 2) diff --git a/MI/Algorithm/conformer_2class_cpu.py b/MI/Algorithm/conformer_2class_cpu.py index 6e29bc3..1ffe523 100644 --- a/MI/Algorithm/conformer_2class_cpu.py +++ b/MI/Algorithm/conformer_2class_cpu.py @@ -71,7 +71,7 @@ class MultiHeadAttention(nn.Module): values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads) energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) if mask is not None: - fill_value = torch.finfo(torch.float32).min + fill_value = torch.finfo(torch.float64).min energy.mask_fill(~mask, fill_value) scaling = self.emb_size ** (1 / 2) diff --git a/Zmq/dataBuffer.py b/Zmq/dataBuffer.py index f08a35b..233b72f 100644 --- a/Zmq/dataBuffer.py +++ b/Zmq/dataBuffer.py @@ -11,7 +11,7 @@ class ParadigmRingBuffer: def __init__(self, n_chan, n_points): self.n_chan = n_chan self.n_points = n_points - self.buffer = np.zeros((n_chan, n_points), dtype=np.float32) + self.buffer = np.zeros((n_chan, n_points), dtype=np.float64) self.currentPtr = 0 self.readPtr = 0 self.nUpdate = 0 diff --git a/Zmq/filterProcess.py b/Zmq/filterProcess.py index 7bafc6b..3f4dee7 100644 --- a/Zmq/filterProcess.py +++ b/Zmq/filterProcess.py @@ -18,7 +18,7 @@ class FilterRingBuffer: self.n_chan = n_chan self.n_points = n_points - self.buffer = np.zeros((n_chan, n_points), dtype=np.float32) + self.buffer = np.zeros((n_chan, n_points), dtype=np.float64) self.current_ptr = 0 # 写入指针:指向下一个要写入的位置 self.total_samples = 0 # 已写入总点数 self.lock = threading.Lock() # 线程安全锁 diff --git a/Zmq/zmqServer.py b/Zmq/zmqServer.py index d47e5e2..fa45850 100644 --- a/Zmq/zmqServer.py +++ b/Zmq/zmqServer.py @@ -174,7 +174,7 @@ class zmqServer(threading.Thread): return # 转置为上位机需要的[50, 通道数]格式 - filtered_data = filtered_data.T.astype(np.float32) + filtered_data = filtered_data.T.astype(np.float64) send_buf = filtered_data.tobytes() algo_log(f"发送滤波数据,长度: {len(send_buf)}字节, filtered_data.shape: {filtered_data.shape}", level="DEBUG") self.data_send_queue.put(send_buf) @@ -292,7 +292,7 @@ class zmqServer(threading.Thread): return # 零拷贝解析 + 维度转换 - data_np = np.frombuffer(data_bytes, dtype=np.float32) + data_np = np.frombuffer(data_bytes, dtype=np.float64) data_np = data_np.reshape(self.device_info['frame_points'], self.device_info['channel_nums']) data_np = data_np.T.astype(np.float64) diff --git a/datamock.py b/datamock.py index 7ad703d..5379281 100644 --- a/datamock.py +++ b/datamock.py @@ -18,7 +18,7 @@ PKT_INTERVAL = N_SAMPLES_PER_PKT / FS def build_packet(global_sample_idx): """ - 生成一包 [5, 66] 的 float32 数据 + 生成一包 [5, 66] 的 float64 数据 :param global_sample_idx: 当前包第一个采样点在全局序列中的索引 (从 0 开始) :return: np.ndarray shape [5, 66] """ @@ -31,13 +31,13 @@ def build_packet(global_sample_idx): eeg = np.tile(eeg, (1, 64)) # [5, 64] # Ch64: 标签值通道,初始化为 0 - event = np.zeros((N_SAMPLES_PER_PKT, 1), dtype=np.float32) + event = np.zeros((N_SAMPLES_PER_PKT, 1), dtype=np.float64) # Ch65: 标签序号通道,初始化为 0 - label_idx = np.zeros((N_SAMPLES_PER_PKT, 1), dtype=np.float32) + label_idx = np.zeros((N_SAMPLES_PER_PKT, 1), dtype=np.float64) # 拼成 [5, 66] - packet = np.concatenate([eeg, event, label_idx], axis=1).astype(np.float32) + packet = np.concatenate([eeg, event, label_idx], axis=1).astype(np.float64) return packet