original push

This commit is contained in:
Ivey Song
2026-06-01 13:18:36 +08:00
commit 8426770db6
46 changed files with 341750 additions and 0 deletions

55
algorithm_V0/.gitignore vendored Normal file
View File

@@ -0,0 +1,55 @@
# Byte-compiled / optimized / DLL files
__pycache__/
# Distribution / packaging
build/
dist/
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# data format
*.dat
*.csv
*.edf
*.event
*.edf.event
*.zip
*.xlsx
*.mat
*.json
# PyCharm
# JetBrains specific template is maintained in a separate repository that is not distributed with PyCharm itself
# https://github.com/github/gitignore/blob/main/Python.gitignore
.idea/
# Logs
*.log
# Other common ignores
node_modules/
dist/
tmp/
temp/
# Project-specific ignores
# Ignore all directories in the root
# merge64ch_0127/
/P300_speller/braindecode/
/P300_speller/data/
/P300_speller/pyRiemann/
/P300_speller/README/
/merge64ch_new/
/merge64ch_tianjinZMQdebug/

View File

@@ -0,0 +1,96 @@
# -*- mode: python ; coding: utf-8 -*-
import sys
import os
from PyInstaller.utils.hooks import collect_submodules, collect_data_files
# ========================================================
# 1. 工程配置区 (Project Config)
# ========================================================
block_cipher = None
ENTRY_POINT = 'runDecoder.py'
APP_NAME = 'Depression_Decoder'
# ========================================================
# 2. 依赖分析 (Dependency Analysis)
# ========================================================
# 收集 mne, sklearn, scipy 可能遗漏的隐藏导入
hidden_imports = [
'infer_pth', # 你的动态导入模块
'sklearn.utils._cython_blas',
'sklearn.neighbors.typedefs',
'sklearn.neighbors.quad_tree',
'sklearn.tree',
'sklearn.tree._utils',
]
# 自动收集 mne 的子模块
hidden_imports += collect_submodules('mne')
# 收集 torch 相关的隐式导入(虽然 PyInstaller 通常能处理,但显式更安全)
hidden_imports += ['torch', 'torchvision']
# ========================================================
# 3. 资源锚定 (Data Anchoring)
# ========================================================
# Analysis 中的 datas 用于将文件嵌入到内部(运行时在临时目录或 _internal
# 这里我们留空,改为在 COLLECT 阶段通过 Tree 显式复制到 EXE 旁,
# 这样生成的文件夹里能直接看到 model 和 raw_data
datas = []
# 收集 mne 的数据文件(如果需要默认配置)
datas += collect_data_files('mne')
# ========================================================
# 4. 构建流程 (Build Process)
# ========================================================
a = Analysis(
[ENTRY_POINT],
pathex=[],
binaries=[],
datas=datas,
hiddenimports=hidden_imports,
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=['tkinter', 'PyQt5', 'PySide2', 'IPython'], # 排除 GUI 和交互式库减小体积
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
pyz,
a.scripts,
[],
exclude_binaries=True,
name=APP_NAME,
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=False,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)
# ========================================================
# 5. 打包模式: OneDir (单文件夹) + 资源旁路
# ========================================================
# 使用 Tree 将文件夹原样复制到 dist/APP_NAME/ 下
# 格式: Tree('源路径', prefix='目标子目录')
coll = COLLECT(
exe,
a.binaries,
a.zipfiles,
a.datas,
strip=False,
upx=False,
upx_exclude=[],
name=APP_NAME,
)

View File

@@ -0,0 +1,87 @@
import os
import subprocess
import sys
import shutil
# 确保我们在虚拟环境中运行
if not sys.prefix == sys.base_prefix:
print(f"正在使用虚拟环境: {sys.prefix}")
else:
print("警告:你似乎没有激活虚拟环境!建议在 venv_clean 下运行。")
def build():
entry_point = "runDecoder.py" # 你的入口文件
# 自动清理逻辑优化
output_dir = "dist2"
build_dir = "build2" # Nuitka 默认会在当前目录生成 .build 文件夹
if "--clean" in sys.argv:
print("清理旧构建目录...")
for folder in [output_dir, build_dir, entry_point.replace(".py", ".build")]:
if os.path.exists(folder):
shutil.rmtree(folder, ignore_errors=True)
# Nuitka 命令 - 此时非常清爽
nuitka_cmd = [
sys.executable, "-m", "nuitka",
"--standalone", # 独立运行模式
f"--output-dir={output_dir}", # 输出目录
"--show-progress", # 显示进度
"--assume-yes-for-downloads", # 自动下载依赖(如 ccache, depends 等)
# --- 插件配置 ---
"--enable-plugin=numpy",
"--enable-plugin=matplotlib",
"--enable-plugin=torch", # 处理 PyTorch 及其 CUDA 依赖
# --- 包含包/模块 (Nuitka 2.x 推荐使用 include-package-data 或 include-package) ---
# --collect-all 是 PyInstaller 的参数Nuitka 不支持
"--include-package=sklearn",
"--include-package=scipy",
"--include-package=mne",
# 强制包含 MNE 的数据文件(配置、布局等)
"--include-package-data=mne",
"--include-package=PIL", # Pillow (matplotlib/mne 可能用到)
"--include-package=networkx", # mne 可能用到
"--include-package=decorator", # MNE 核心依赖,防止 KeyError: 'self'
"--include-package=six", # 通用兼容库
# 显式包含本地模块,防止隐式导入丢失
"--include-module=infer_pth",
# --- 数据文件 ---
# 格式: 源路径=目标路径 (相对 dist 目录)
"--include-data-dir=model=model",
"--include-data-dir=raw_data=raw_data",
# --- 排除干扰以减小体积/提高稳定性 ---
"--nofollow-import-to=pytest",
"--nofollow-import-to=unittest",
"--nofollow-import-to=pdb",
"--nofollow-import-to=tkinter", # 如果不用 GUI 界面
"--nofollow-import-to=sympy", # 除非明确用到符号计算
# --- 内存与性能 ---
"--low-memory", # 降低打包时的内存消耗
# --- Windows 特定 ---
# "--disable-console", # 如果不需要黑框,取消注释这一行
]
nuitka_cmd.append(entry_point)
print("开始打包...")
try:
subprocess.check_call(nuitka_cmd)
print("\n打包成功!")
print(f"请在 dist2/runDecoder.dist 目录下运行 exe 进行测试。")
except subprocess.CalledProcessError as e:
print(f"打包失败,错误码: {e.returncode}")
if __name__ == "__main__":
build()

View File

@@ -0,0 +1,77 @@
import os
import shutil
import subprocess
import sys
def main():
# 1. 定义路径
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DIST_DIR = os.path.join(BASE_DIR, 'dist')
APP_NAME = 'Depression_Decoder'
TARGET_DIR = os.path.join(DIST_DIR, APP_NAME)
MODEL_SRC = os.path.join(BASE_DIR, 'model')
RAW_DATA_SRC = os.path.join(BASE_DIR, 'raw_data')
MODEL_DST = os.path.join(TARGET_DIR, 'model')
RAW_DATA_DST = os.path.join(TARGET_DIR, 'raw_data')
# 2. 清理旧构建
print("[1/3] Cleaning up old builds...")
if os.path.exists(DIST_DIR):
try:
shutil.rmtree(DIST_DIR)
print(" Cleaned dist/")
except Exception as e:
print(f" Warning: Could not clean dist/: {e}")
BUILD_DIR = os.path.join(BASE_DIR, 'build')
if os.path.exists(BUILD_DIR):
try:
shutil.rmtree(BUILD_DIR)
print(" Cleaned build/")
except Exception as e:
print(f" Warning: Could not clean build/: {e}")
# 3. 运行 PyInstaller
print("[2/3] Running PyInstaller...")
# 注意:我们这里不传 --noupx因为已经在 spec 文件里把 upx=False 写死了
cmd = [
"pyinstaller",
"build_algorithm.spec",
"--clean"
]
try:
subprocess.check_call(cmd, shell=True)
except subprocess.CalledProcessError:
print("Error: PyInstaller failed.")
sys.exit(1)
# 4. 复制外部资源文件夹
print("[3/3] Copying external resources...")
# 复制 model 文件夹
if os.path.exists(MODEL_SRC):
if os.path.exists(MODEL_DST):
shutil.rmtree(MODEL_DST)
shutil.copytree(MODEL_SRC, MODEL_DST)
print(f" Copied: model -> {MODEL_DST}")
else:
print(f" Warning: Source model dir not found at {MODEL_SRC}")
# 复制 raw_data 文件夹
if os.path.exists(RAW_DATA_SRC):
if os.path.exists(RAW_DATA_DST):
shutil.rmtree(RAW_DATA_DST)
shutil.copytree(RAW_DATA_SRC, RAW_DATA_DST)
print(f" Copied: raw_data -> {RAW_DATA_DST}")
else:
print(f" Warning: Source raw_data dir not found at {RAW_DATA_SRC}")
print("\n" + "="*50)
print(f"SUCCESS! Build artifacts are in: {TARGET_DIR}")
print("="*50)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,38 @@
import os
import sys
import scipy
import numpy
print(f"Python executable: {sys.executable}")
print(f"Scipy version: {scipy.__version__}")
print(f"Scipy path: {scipy.__file__}")
scipy_dir = os.path.dirname(scipy.__file__)
parent_dir = os.path.dirname(scipy_dir)
scipy_libs = os.path.join(parent_dir, "scipy.libs")
print(f"Checking for scipy.libs at: {scipy_libs}")
if os.path.exists(scipy_libs):
print("scipy.libs FOUND.")
for root, dirs, files in os.walk(scipy_libs):
for f in files:
print(f" - {f}")
else:
print("scipy.libs NOT FOUND.")
print("-" * 20)
print(f"Numpy version: {numpy.__version__}")
print(f"Numpy path: {numpy.__file__}")
numpy_dir = os.path.dirname(numpy.__file__)
numpy_libs = os.path.join(numpy_dir, ".libs") # numpy 往往在内部
if not os.path.exists(numpy_libs):
# try parent
numpy_libs = os.path.join(os.path.dirname(numpy_dir), "numpy.libs")
print(f"Checking for numpy libs at: {numpy_libs}")
if os.path.exists(numpy_libs):
print("numpy libs FOUND.")
for root, dirs, files in os.walk(numpy_libs):
for f in files:
print(f" - {f}")
else:
print("numpy libs NOT FOUND.")

View File

@@ -0,0 +1,561 @@
# -*- coding: utf-8 -*-
"""
infer_pth.py
用途:
- 从一个文件夹中自动读取第一个 .mat EEG 文件64通道或32通道
- 若为64通道则按 idx64_to_32 映射选出32通道
- 提取切片特征DE + PSD(var近似)不含Asym
- 加载你训练好的 .pth 模型FusionNet结构
- 输出该受试者的 HC / MDD 判断结果
运行方式(命令行):
python infer_pth_from64_to32.py --eeg_dir "D:\\xxx\\folder" --model_path "C:\\xxx\\model.pth"
也可在其他py里import
from infer_pth_from64_to32 import predict_hc_mdd
res = predict_hc_mdd(eeg_dir, model_path)
print(res)
"""
from __future__ import annotations
import os
import argparse
import numpy as np
import scipy.io
import scipy.signal as signal
import torch
import torch.nn as nn
import torch.nn.functional as F
# =========================================================
# 0) 配置区(按需改这里)
# =========================================================
# 采样率(必须与训练时一致)
SAMPLING_RATE = 250
# 滑窗参数(必须与训练时一致)
WINDOW_SIZE = 500
STRIDE = 250
# 频段(必须与训练时一致)
BAND_NAMES = ["Delta", "Theta", "Alpha", "Beta", "Gamma"]
BANDS = {
"Delta": (1, 4),
"Theta": (4, 8),
"Alpha": (8, 13),
"Beta": (13, 30),
"Gamma": (30, 50),
}
# 是否使用扩展特征DE+PSD
USE_EXTENDED_FEATURES = True
# 数值稳定项
EPS = 1e-12
# 设备
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 通道映射
IDX64_TO_32 = [
23, # C5
47, # O1
39, # TP7
6, # FPZ
2, # PO6
21, # P4
35, # AF7
57, # AF3
1, # FP2
37, # T7
63, # F1
36, # A1
18, # FC4
31, # FC5
14, # FC2
48, # T8
60, # P2
41, # AF8
11, # CP1
0, # FP1
55, # PO7
59, # C1
22, # F5
10, # CP2
16, # C3
61, # P1
27, # CP5
17, # C4
26, # CP6
62, # F2
3, # POZ
13, # PO5
]
# 推理阈值如果模型checkpoint里有 subject_threshold会优先用它否则用这个
DEFAULT_SUBJECT_THRESHOLD = 0.5
# =========================================================
# 1) 模型结构
# =========================================================
class SEBlock(nn.Module):
def __init__(self, channels: int, reduction: int = 4) -> None:
super().__init__()
hidden = max(1, channels // reduction)
self.fc1 = nn.Linear(channels, hidden)
self.fc2 = nn.Linear(hidden, channels)
def forward(self, x: torch.Tensor) -> torch.Tensor:
se = F.relu(self.fc1(x))
se = torch.sigmoid(self.fc2(se))
return x * se
class ResidualBlock(nn.Module):
def __init__(self, in_features: int, out_features: int, dropout: float = 0.3) -> None:
super().__init__()
self.fc1 = nn.Linear(in_features, out_features)
self.bn1 = nn.BatchNorm1d(out_features)
self.fc2 = nn.Linear(out_features, out_features)
self.bn2 = nn.BatchNorm1d(out_features)
self.dropout = nn.Dropout(dropout)
self.shortcut = nn.Identity()
if in_features != out_features:
self.shortcut = nn.Sequential(
nn.Linear(in_features, out_features),
nn.BatchNorm1d(out_features),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
identity = self.shortcut(x)
out = F.relu(self.bn1(self.fc1(x)))
out = self.dropout(out)
out = self.bn2(self.fc2(out))
out = F.relu(out + identity)
return out
class FusionNet(nn.Module):
def __init__(self, num_classes: int = 2, num_eeg_features: int = 320, num_scales: int = 6) -> None:
super().__init__()
self.input_norm = nn.BatchNorm1d(num_eeg_features)
self.block1 = ResidualBlock(num_eeg_features, 512, dropout=0.4)
self.block2 = ResidualBlock(512, 256, dropout=0.3)
self.block3 = ResidualBlock(256, 128, dropout=0.2)
self.attention = SEBlock(128, reduction=4)
self.final_fc = nn.Sequential(
nn.Linear(128, 64),
nn.BatchNorm1d(64),
nn.ReLU(),
nn.Dropout(0.2),
)
self.cls_head = nn.Linear(64, num_classes)
# 训练时有回归头也没关系推理只用cls
self.reg_head = nn.Sequential(
nn.Linear(64, 32),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(32, num_scales),
)
self._init_weights()
def _init_weights(self) -> None:
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor):
x = self.input_norm(x)
x = self.block1(x)
x = self.block2(x)
x = self.block3(x)
x = self.attention(x)
features = self.final_fc(x)
cls_out = self.cls_head(features)
reg_out = self.reg_head(features)
return cls_out, reg_out
# =========================================================
# 2) mat读取 + 通道裁剪
# =========================================================
def _find_first_mat_file(folder: str) -> str:
if not os.path.isdir(folder):
raise RuntimeError(f"eeg_dir 不是文件夹: {folder}")
mats = sorted([f for f in os.listdir(folder) if f.lower().endswith(".mat")])
if not mats:
raise RuntimeError(f"文件夹内没有 .mat 文件: {folder}")
return os.path.join(folder, mats[0])
import numpy as np
import scipy.io
def _unwrap_singleton(x):
"""
把 (1,1) / (1,) 这种包裹层一直剥掉,直到不是 singleton。
也处理 object array 的情况。
"""
while True:
if isinstance(x, np.ndarray):
if x.dtype == object and x.size == 1:
x = x.item()
continue
if x.size == 1 and x.ndim >= 1:
# 例如 (1,1) 或 (1,) 的数值/对象数组
try:
x = x.reshape(-1)[0]
continue
except Exception:
pass
break
return x
def _try_get_struct_field(v, field_name="data"):
"""
尝试从以下几种结构中提取字段:
1) scipy 读出的 mat_struct有 _fieldnames
2) numpy structured/record arraydtype.names
"""
# case 1: mat_struct推荐 loadmat(..., struct_as_record=False, squeeze_me=True)
if hasattr(v, "_fieldnames") and (field_name in getattr(v, "_fieldnames", [])):
return getattr(v, field_name)
# case 2: structured array
if isinstance(v, np.ndarray) and v.dtype.names and (field_name in v.dtype.names):
# 常见是 v[field] 仍然是 ndarray / object需要 unwrap
try:
return v[field_name]
except Exception:
return None
return None
def load_eeg_from_mat_any_channels(mat_path: str) -> np.ndarray:
"""
读取 .mat 中 EEG 数据,支持:
- 直接二维矩阵 (T,C) 或 (C,T)
- struct 里有字段 data
返回统一为 float32 的 (T, C)
"""
# 用这两个参数会让 struct 更容易处理:字段变成属性,且自动 squeeze
mat = scipy.io.loadmat(mat_path, struct_as_record=False, squeeze_me=True)
candidates = []
for k, v in mat.items():
if k.startswith("__"):
continue
# --- 1) 直接二维数值矩阵 ---
if isinstance(v, np.ndarray) and v.ndim == 2 and np.issubdtype(v.dtype, np.number):
candidates.append((k, v))
continue
# --- 2) struct/record优先提取 data 字段 ---
data_field = _try_get_struct_field(v, "data")
if data_field is not None:
data_field = _unwrap_singleton(data_field)
# data_field 可能仍然被 object 包一层
if isinstance(data_field, np.ndarray) and data_field.dtype == object:
data_field = _unwrap_singleton(data_field)
if isinstance(data_field, np.ndarray) and data_field.ndim == 2:
# 只收数值矩阵
if np.issubdtype(data_field.dtype, np.number) or data_field.dtype == object:
candidates.append((f"{k}.data", data_field))
continue
# --- 3) object array尝试 item() 解包后再看是不是二维数值矩阵/struct ---
if isinstance(v, np.ndarray) and v.dtype == object:
vv = _unwrap_singleton(v)
# 解包后若是二维数值矩阵
if isinstance(vv, np.ndarray) and vv.ndim == 2 and np.issubdtype(vv.dtype, np.number):
candidates.append((k, vv))
continue
# 解包后若是 struct再取 data
data2 = _try_get_struct_field(vv, "data")
if data2 is not None:
data2 = _unwrap_singleton(data2)
if isinstance(data2, np.ndarray) and data2.ndim == 2:
candidates.append((f"{k}.data", data2))
continue
if not candidates:
raise RuntimeError(f"mat里没找到可用EEG二维矩阵或struct.data{mat_path}")
# 选一个最像 EEG 的优先含32/64通道维度的
def score(arr: np.ndarray) -> int:
s = 0
if 64 in arr.shape: s += 10
if 32 in arr.shape: s += 9
if 128 in arr.shape: s += 8
if 129 in arr.shape: s += 7
s += int(np.prod(arr.shape) // 100000) # 大一些更像EEG
return s
candidates.sort(key=lambda kv: score(kv[1]), reverse=True)
key, eeg = candidates[0]
eeg = _unwrap_singleton(eeg)
# 如果还是 object dtype尽力转成 float
if isinstance(eeg, np.ndarray) and eeg.dtype == object:
# 有时 object 里其实是数值
eeg = np.array(eeg, dtype=np.float32)
else:
eeg = np.asarray(eeg, dtype=np.float32)
if eeg.ndim != 2:
raise RuntimeError(f"解析结果不是二维矩阵: key={key}, shape={eeg.shape}, file={mat_path}")
# 统一为 (T, C)
# 常见 (C,T) 或 (T,C),我们用“通道维通常较小”+ “32/64/128/129”判断
if eeg.shape[0] in (32, 64, 128, 129) and eeg.shape[1] not in (32, 64, 128, 129):
eeg = eeg.T
elif eeg.shape[1] in (32, 64, 128, 129):
# 如果第一维也是这些数且更小,可能是(C,T)
if eeg.shape[0] in (32, 64, 128, 129) and eeg.shape[0] < eeg.shape[1]:
eeg = eeg.T
return eeg
def ensure_32_channels(eeg: np.ndarray) -> np.ndarray:
"""
输入 (T, C),输出 (T, 32)
- 若C=64按 IDX64_TO_32 选32通道
- 若C=32直接返回
"""
if eeg.ndim != 2:
raise RuntimeError(f"EEG必须是二维(T,C),但得到: {eeg.shape}")
C = eeg.shape[1]
if C == 64:
idx = np.asarray(IDX64_TO_32, dtype=np.int64)
if idx.min() < 0 or idx.max() >= 64:
raise RuntimeError(f"IDX64_TO_32 越界min={idx.min()}, max={idx.max()} (要求0~63)")
return eeg[:, idx]
if C == 32:
return eeg
raise RuntimeError(f"不支持的通道数C={C},当前只支持 64->32 或 32 直推。")
# =========================================================
# 3) 特征提取DE + PSD(var近似)
# =========================================================
class FeatureExtractor32:
"""
只针对32通道输出维度
- USE_EXTENDED_FEATURES=TrueDE(32*5) + PSD(32*5) = 320
- 否则DE(32*5) = 160
"""
def __init__(
self,
fs: int = SAMPLING_RATE,
window_size: int = WINDOW_SIZE,
stride: int = STRIDE,
filter_order: int = 4,
zero_phase: bool = False,
) -> None:
self.fs = fs
self.window_size = window_size
self.stride = stride
self.filter_order = filter_order
self.zero_phase = zero_phase
self._sos = {}
for bn in BAND_NAMES:
low, high = BANDS[bn]
self._sos[bn] = signal.butter(
self.filter_order, [low, high],
btype="band", fs=self.fs, output="sos"
)
def _filter_bands(self, eeg: np.ndarray) -> dict[str, np.ndarray]:
out = {}
for bn in BAND_NAMES:
sos = self._sos[bn]
if self.zero_phase:
out[bn] = signal.sosfiltfilt(sos, eeg, axis=0).astype(np.float32)
else:
out[bn] = signal.sosfilt(sos, eeg, axis=0).astype(np.float32)
return out
def extract(self, eeg32: np.ndarray) -> np.ndarray:
"""
eeg32: (T, 32)
return: feats (N_slices, feat_dim)
"""
if eeg32.ndim != 2 or eeg32.shape[1] != 32:
raise RuntimeError(f"extract需要 (T,32),得到 {eeg32.shape}")
bands_data = self._filter_bands(eeg32)
T = eeg32.shape[0]
feats = []
for start in range(0, T - self.window_size, self.stride):
end = start + self.window_size
de_list = []
psd_list = []
for bn in BAND_NAMES:
seg = bands_data[bn][start:end, :] # (W, 32)
var = np.var(seg, axis=0, ddof=1) # (32,)
# DE
de = 0.5 * np.log(2 * np.pi * np.e * (var + EPS))
de_list.append(de)
if USE_EXTENDED_FEATURES:
# PSD近似log(var)
psd_list.append(np.log(var + EPS))
de_feat = np.stack(de_list, axis=0).T.reshape(-1) # (32*5,)
if USE_EXTENDED_FEATURES:
psd_feat = np.stack(psd_list, axis=0).T.reshape(-1) # (32*5,)
f = np.concatenate([de_feat, psd_feat], axis=0).astype(np.float32)
else:
f = de_feat.astype(np.float32)
feats.append(f)
if not feats:
raise RuntimeError("EEG长度不足以切片请检查T是否太短或调整WINDOW_SIZE/STRIDE")
return np.stack(feats, axis=0).astype(np.float32)
# =========================================================
# 4) 模型加载 + 推理接口
# =========================================================
def _safe_torch_load(path: str):
try:
return torch.load(path, map_location=DEVICE, weights_only=False)
except TypeError:
return torch.load(path, map_location=DEVICE)
def load_model(model_path: str) -> tuple[FusionNet, dict]:
"""
返回: (model, ckpt_dict)
"""
obj = _safe_torch_load(model_path)
if isinstance(obj, dict) and "model_state" in obj:
ckpt = obj
state = obj["model_state"]
feat_dim = int(obj.get("feat_dim", 320))
else:
ckpt = {}
state = obj
feat_dim = 320
model = FusionNet(num_classes=2, num_eeg_features=feat_dim).to(DEVICE)
model.load_state_dict(state, strict=True)
model.eval()
return model, ckpt
def predict_hc_mdd(eeg_dir: str, model_path: str) -> dict:
"""
接口:传入 EEG文件夹 和 模型路径,返回判断结果 dict
返回字段:
- mat_file: 使用的mat文件
- pred_label: "HC" or "MDD"
- p_mdd_mean: 切片p(MDD)均值
- threshold: subject判定阈值
- n_slices: 切片数
"""
mat_file = _find_first_mat_file(eeg_dir)
# 1) 读EEG (T,C)并保证变成32通道
eeg = load_eeg_from_mat_any_channels(mat_file) # (T,C)
eeg32 = ensure_32_channels(eeg) # (T,32)
# 2) 提特征 (N,feat_dim)
extractor = FeatureExtractor32(fs=SAMPLING_RATE, window_size=WINDOW_SIZE, stride=STRIDE)
feats = extractor.extract(eeg32) # (N, dim)
# 3) 加载模型
model, ckpt = load_model(model_path)
# 4) 可选归一化若ckpt里保存的mean/std维度刚好匹配
mean = ckpt.get("global_mean", None) if isinstance(ckpt, dict) else None
std = ckpt.get("global_std", None) if isinstance(ckpt, dict) else None
if mean is not None and std is not None:
mean = np.asarray(mean, dtype=np.float32)
std = np.asarray(std, dtype=np.float32)
if mean.shape[0] == feats.shape[1] and std.shape[0] == feats.shape[1]:
feats = (feats - mean) / (std + 1e-8)
# 不匹配就跳过
# 你说“先不用其他步骤或信息”,所以这里按“尽量运行”处理
# 5) 推理:对所有切片算 p(MDD)取均值做subject-level
x = torch.from_numpy(feats).to(DEVICE)
with torch.no_grad():
cls_out, _ = model(x)
prob_mdd = torch.softmax(cls_out, dim=1)[:, 1].detach().cpu().numpy()
p_mdd_mean = float(np.mean(prob_mdd))
thr = float(ckpt.get("subject_threshold", DEFAULT_SUBJECT_THRESHOLD)) if isinstance(ckpt, dict) else float(DEFAULT_SUBJECT_THRESHOLD)
pred_is_mdd = (p_mdd_mean >= thr)
pred_label = "MDD" if pred_is_mdd else "HC"
return {
"mat_file": mat_file,
"pred_label": pred_label,
"p_mdd_mean": p_mdd_mean,
"threshold": thr,
"n_slices": int(feats.shape[0]),
}
# =========================================================
# 5) CLI命令行运行入口
# =========================================================
def main():
parser = argparse.ArgumentParser(description="Infer HC/MDD from 64ch->32ch EEG mat using a .pth FusionNet model (no Asym).")
parser.add_argument("--eeg_dir", type=str, required=True, help="包含.mat EEG文件的文件夹自动读取第一个.mat")
parser.add_argument("--model_path", type=str, required=True, help="训练好的.pth模型路径")
args = parser.parse_args()
res = predict_hc_mdd(args.eeg_dir, args.model_path)
print("\n========== 推理结果 ==========")
print(f"MAT文件: {res['mat_file']}")
print(f"切片数量: {res['n_slices']}")
print(f"p(MDD)_mean: {res['p_mdd_mean']:.4f}")
print(f"阈值thr: {res['threshold']:.4f}")
print(f"预测结果: {res['pred_label']}")
print("==============================\n")
if __name__ == "__main__":
main()

Binary file not shown.

Binary file not shown.

Binary file not shown.

After

Width:  |  Height:  |  Size: 306 KiB

View File

@@ -0,0 +1,9 @@
中央区α/β波比值:1.2
额区α/β波比值:1.3
顶区α/β波比值:1.2
中央区θ/β波比值:3.2
顶区θ/β波比值:3.5
前额叶α波不对称性:0.3
个体化α峰值频率:8.5
前额叶θ+δ波功率:93.8
是否推荐治疗:否

Binary file not shown.

After

Width:  |  Height:  |  Size: 268 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 61 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 493 KiB

View File

@@ -0,0 +1,6 @@
numpy
scipy
matplotlib
mne
torch
scikit-learn

View File

@@ -0,0 +1,909 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
"""
run_metrics_and_figs.py
1) 自动读取 mat_dir 中排序后的第一个 .mat
2) 调用模型预测HC/MDD并写 ResultData.txt
3) 同时保存图片EEG.png / psd.png / average_topomap.png / topomaps.png
"""
import matplotlib
matplotlib.use('Agg')
import numpy as np
import os
import shutil
import scipy.io
import scipy.signal as signal
import matplotlib.pyplot as plt
import mne
from mne.preprocessing import ICA
# ==========================
# Config
# ==========================
PREPROCESS_BANDPASS = (0.8, 30.0)
PREPROCESS_NOTCH = [50, 100]
PREPROCESS_ICA_N = 0.99
PREPROCESS_ICA_SEED = 97
PREPROCESS_APPLY_AVG_REF = True
PREPROCESS_BAD_PTP_UV = 350.0 # 坏段阈值 (μV)
DEFAULT_FS = 250.0
EEG_PLOT_SECONDS = 10
PSD_FMIN, PSD_FMAX = 0.8, 45.0
EPS = 1e-12
FIXED_EEG_IDXS = [23, 47, 39, 6, 2, 21, 35, 57] # 0-based index, 按重要性排序
FIXED_EEG_LABELS = ["C5", "O1", "TP7", "FPZ", "PO6", "P4", "AF7", "AF3"]
BANDS_METRICS = {
"Delta": (1.0, 4.0),
"Theta": (4.0, 8.0),
"Alpha": (8.0, 13.0),
"Beta": (13.0, 30.0),
}
TOTAL_POWER_BAND = (1.0, 50.0)
BANDS_TOPOMAP = {
"delta": (0.8, 3.9),
"theta": (4.0, 7.9),
"alpha": (8.0, 12.9),
"beta": (13.0, 30.0),
"broad": (0.8, 30.0),
}
# ==========================
# 预处理逻辑
# ==========================
def annotate_bad_segments(raw, peak_to_peak_uv=250.0):
"""
简单坏段检测:按固定窗口计算峰峰值,超过阈值标为 bad。
"""
peak_to_peak_v = peak_to_peak_uv * 1e-6
win = int(raw.info["sfreq"] * 1.0)
step = int(raw.info["sfreq"] * 0.5)
data = raw.get_data()
n_times = data.shape[1]
onsets = []
durations = []
descriptions = []
for start in range(0, n_times - win, step):
seg = data[:, start:start + win]
ptp = np.ptp(seg, axis=1)
if np.any(ptp > peak_to_peak_v):
onsets.append(start / raw.info["sfreq"])
durations.append(win / raw.info["sfreq"])
descriptions.append("BAD_PTP")
if len(onsets) > 0:
ann = mne.Annotations(onset=onsets, duration=durations, description=descriptions)
raw.set_annotations(ann)
print(f"[INFO] Annotated bad segments: {len(onsets)} windows")
else:
print("[INFO] No bad segments detected by PTP rule")
def run_preprocess_on_raw(raw: mne.io.RawArray) -> mne.io.RawArray:
"""
核心预处理:滤波 + 平均参考 + 坏段标注 + ICA
"""
# 1) 滤波
raw.filter(PREPROCESS_BANDPASS[0], PREPROCESS_BANDPASS[1], fir_design="firwin", verbose=False)
raw.notch_filter(PREPROCESS_NOTCH, fir_design="firwin", verbose=False)
# 2) 平均参考
if PREPROCESS_APPLY_AVG_REF:
raw.set_eeg_reference("average", verbose=False)
# 3) 坏段标注
annotate_bad_segments(raw, peak_to_peak_uv=PREPROCESS_BAD_PTP_UV)
# 4) ICA
ica = ICA(
n_components=PREPROCESS_ICA_N,
random_state=PREPROCESS_ICA_SEED,
max_iter=800,
method="fastica"
)
ica.fit(raw, reject_by_annotation=True, verbose=False)
try:
eog_inds, _ = ica.find_bads_eog(raw, verbose=False)
if eog_inds:
ica.exclude.extend(eog_inds)
print(f"[INFO] ICA exclude EOG comps: {eog_inds}")
except Exception as e:
print(f"[WARN] ICA find_bads_eog skipped: {e}")
raw_clean = ica.apply(raw.copy(), verbose=False)
return raw_clean
def preprocess_mat_file(src_mat_path: str, temp_out_dir: str) -> str:
"""
读取原始mat -> 预处理 -> 保存到 temp_out_dir -> 返回新路径
"""
os.makedirs(temp_out_dir, exist_ok=True)
# 1. 读原始 mat
# 注意:这里我们只要数据部分转成 MNE Raw然后处理再存回
# 复用现有的 load_eeg_from_mat 拿到 ndarray
eeg_uV, fs, ch_names, xyz = load_eeg_from_mat(src_mat_path)
# 转 MNE (注意单位uV -> V)
if not ch_names:
ch_names = [f"CH{i+1}" for i in range(eeg_uV.shape[1])]
info = mne.create_info(ch_names=ch_names, sfreq=fs, ch_types=["eeg"] * len(ch_names))
raw = mne.io.RawArray(eeg_uV.T * 1e-6, info, verbose=False)
if xyz is not None and isinstance(xyz, np.ndarray):
# 尝试设 montage虽然对滤波不关键但尽量保留信息
try:
ch_pos = {ch_names[i]: xyz[i, :] for i in range(len(ch_names))}
montage = mne.channels.make_dig_montage(ch_pos=ch_pos, coord_frame="head")
raw.set_montage(montage, on_missing="ignore")
except Exception:
pass
# 2. 执行预处理
print(f"[INFO] Start preprocessing: {src_mat_path}")
raw_clean = run_preprocess_on_raw(raw)
# 3. 存回 .mat (保持结构兼容,以便后续 run_all 读取)
# 这里我们需要读取原始 mat 的结构体,把 data 替换掉
try:
mat_struct = scipy.io.loadmat(src_mat_path, struct_as_record=False, squeeze_me=True)
if "eeg" in mat_struct:
eeg_obj = mat_struct["eeg"]
# 替换数据MNE (V) -> uV -> (T, C)
clean_data_uV = (raw_clean.get_data() * 1e6).T
eeg_obj.data = clean_data_uV
base_name = os.path.basename(src_mat_path)
new_path = os.path.join(temp_out_dir, base_name)
scipy.io.savemat(new_path, {"eeg": eeg_obj}, do_compression=True)
print(f"[INFO] Preprocessed file saved to: {new_path}")
return new_path
except Exception as e:
print(f"[WARN] Failed to preserve original struct structure: {e}")
# Fallback: 如果读原始结构失败,就存一个简单的 mat
clean_data_uV = (raw_clean.get_data() * 1e6).T
out_dict = {
"eeg": {
"data": clean_data_uV,
"sample_rate": fs,
"electrode_name": ch_names,
"electrode_xyz": xyz if xyz is not None else []
}
}
base_name = os.path.basename(src_mat_path)
new_path = os.path.join(temp_out_dir, base_name)
scipy.io.savemat(new_path, out_dict, do_compression=True)
print(f"[INFO] Preprocessed file saved (fallback mode) to: {new_path}")
return new_path
# ==========================
# 输出目录
# ==========================
def ensure_outdir(out_root: str) -> str:
"""
确保输出目录存在,并清空除 ResultData.txt 之外的旧文件。
不再创建 timestamp 子文件夹,直接输出到 out_root。
"""
if os.path.exists(out_root):
# 清空目录,但保留 ResultData.txt
for filename in os.listdir(out_root):
if filename == "ResultData.txt":
continue
file_path = os.path.join(out_root, filename)
try:
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path)
elif os.path.isdir(file_path):
shutil.rmtree(file_path)
except Exception as e:
print(f"[WARN] Failed to delete {file_path}. Reason: {e}")
else:
os.makedirs(out_root, exist_ok=True)
return out_root
# ==========================
# 单位自动识别:统一到 μV
# ==========================
def _auto_scale_to_uV(data_nt_nc: np.ndarray):
data = np.asarray(data_nt_nc)
p95 = float(np.percentile(np.abs(data), 95))
if p95 <= 0.5:
data_uV = data * 1e6
msg = f"[UNIT] p95={p95:.3g} -> assume V, convert to μV by *1e6"
elif p95 > 5000:
data_uV = data * 1e-3
msg = f"[UNIT] p95={p95:.3g} -> assume nV, convert to μV by /1000"
else:
data_uV = data
msg = f"[UNIT] p95={p95:.3g} -> assume μV, no scaling"
p95_uV = float(np.percentile(np.abs(data_uV), 95))
warn = None
if p95_uV > 5000:
warn = f"[WARN] After scaling, p95 still large: {p95_uV:.3g} μV"
elif p95_uV < 0.1:
warn = f"[WARN] After scaling, p95 still small: {p95_uV:.3g} μV"
return data_uV, msg, warn
# ==========================
# mat 读取(支持 struct.data / electrode_name / electrode_xyz / sample_rate
# ==========================
def _unwrap_singleton(x):
while True:
if isinstance(x, np.ndarray):
if x.dtype == object and x.size == 1:
x = x.item()
continue
if x.size == 1 and x.ndim >= 1:
try:
x = x.reshape(-1)[0]
continue
except Exception:
pass
break
return x
def _try_get_struct_field(v, field_name="data"):
if hasattr(v, "_fieldnames") and field_name in getattr(v, "_fieldnames", []):
return getattr(v, field_name)
if isinstance(v, np.ndarray) and v.dtype.names and field_name in v.dtype.names:
try:
return v[field_name]
except Exception:
return None
return None
def _extract_electrode_names(st):
nf = _try_get_struct_field(st, "electrode_name")
if nf is None:
return None
nf = _unwrap_singleton(nf)
if isinstance(nf, (list, tuple)):
names = [str(x).strip() for x in nf]
return names if names else None
if isinstance(nf, np.ndarray):
flat = nf.reshape(-1)
names = [str(_unwrap_singleton(x)).strip() for x in flat]
return names if names else None
s = str(nf).strip()
return [s] if s else None
def _extract_sample_rate(st):
sr = _try_get_struct_field(st, "sample_rate")
if sr is None:
return None
sr = _unwrap_singleton(sr)
try:
return float(sr)
except Exception:
return None
def _extract_xyz(st):
xyz = _try_get_struct_field(st, "electrode_xyz")
if xyz is None:
return None
xyz = _unwrap_singleton(xyz)
try:
xyz = np.asarray(xyz, dtype=float)
if xyz.ndim == 2 and xyz.shape[1] == 3:
return xyz
if xyz.ndim == 2 and xyz.shape[0] == 3:
return xyz.T
return None
except Exception:
return None
def load_eeg_from_mat(mat_path: str):
mat = scipy.io.loadmat(mat_path, struct_as_record=False, squeeze_me=True)
candidates = []
st_for_meta = None
for k, v in mat.items():
if k.startswith("__"):
continue
if isinstance(v, np.ndarray) and v.ndim == 2 and np.issubdtype(v.dtype, np.number):
candidates.append((k, v, None))
continue
data_field = _try_get_struct_field(v, "data")
if data_field is not None:
data_field = _unwrap_singleton(data_field)
if isinstance(data_field, np.ndarray) and data_field.ndim == 2:
candidates.append((f"{k}.data", data_field, v))
continue
if isinstance(v, np.ndarray) and v.dtype == object:
vv = _unwrap_singleton(v)
if isinstance(vv, np.ndarray) and vv.ndim == 2 and np.issubdtype(vv.dtype, np.number):
candidates.append((k, vv, None))
continue
data2 = _try_get_struct_field(vv, "data")
if data2 is not None:
data2 = _unwrap_singleton(data2)
if isinstance(data2, np.ndarray) and data2.ndim == 2:
candidates.append((f"{k}.data", data2, vv))
continue
if not candidates:
raise RuntimeError(f"mat 里没找到可用 EEG 二维矩阵或 struct.data{mat_path}")
def score(arr: np.ndarray) -> int:
s = 0
if 64 in arr.shape: s += 10
if 32 in arr.shape: s += 9
if 128 in arr.shape: s += 8
if 129 in arr.shape: s += 7
s += int(np.prod(arr.shape) // 100000)
return s
candidates.sort(key=lambda x: score(x[1]), reverse=True)
key, eeg, st = candidates[0]
st_for_meta = st
eeg = np.asarray(_unwrap_singleton(eeg), dtype=np.float32)
if eeg.ndim != 2:
raise RuntimeError(f"解析结果不是二维: key={key}, shape={eeg.shape}, file={mat_path}")
# 统一成 (T, C)
if eeg.shape[0] in (32, 64, 128, 129) and eeg.shape[1] not in (32, 64, 128, 129):
eeg = eeg.T
elif eeg.shape[1] in (32, 64, 128, 129):
if eeg.shape[0] in (32, 64, 128, 129) and eeg.shape[0] < eeg.shape[1]:
eeg = eeg.T
fs = DEFAULT_FS
ch_names = None
xyz = None
if st_for_meta is not None:
fs2 = _extract_sample_rate(st_for_meta)
if fs2 is not None and fs2 > 1:
fs = float(fs2)
ch_names = _extract_electrode_names(st_for_meta)
xyz = _extract_xyz(st_for_meta)
eeg_uV, msg, warn = _auto_scale_to_uV(eeg)
print(msg)
if warn:
print(warn)
return eeg_uV.astype(np.float32), float(fs), ch_names, xyz
# ==========================
# 预测接口:导入 predict_hc_mdd
# ==========================
def _predict_label_by_model(model_path: str, mat_dir: str) -> str:
try:
from infer_pth import predict_hc_mdd
except Exception as e:
raise RuntimeError(
"无法导入 predict_hc_mdd请确保 pre.py 或 infer_pth.py 与本文件同目录)。\n"
f"原始错误: {e}"
)
try:
out = predict_hc_mdd(mat_dir, model_path)
except TypeError:
out = predict_hc_mdd(model_path, mat_dir)
label = str(out.get("pred_label", "")).strip().upper()
if label not in ("HC", "MDD"):
raise RuntimeError(f"predict_hc_mdd 返回 pred_label 非法: {label},原始返回: {out}")
return label
# ==========================
# 通道分区
# ==========================
def _norm_name(s: str) -> str:
return str(s).strip().upper().replace(" ", "")
def build_channel_index_map(ch_names, n_channels: int):
if not ch_names or len(ch_names) != n_channels:
return {}
return {_norm_name(nm): i for i, nm in enumerate(ch_names)}
def pick_indices_by_names(name_to_idx, names):
idx = []
for n in names:
nn = _norm_name(n)
if nn in name_to_idx:
idx.append(name_to_idx[nn])
return sorted(list(set(idx)))
def _fallback_region_indices(n_channels: int):
a = int(n_channels * 0.33)
b = int(n_channels * 0.66)
frontal = list(range(0, a))
central = list(range(a, b))
parietal = list(range(b, n_channels))
prefrontal = list(range(0, max(2, a // 2)))
posterior = list(range(b, n_channels))
left = [i for i in range(n_channels) if i % 2 == 0]
right = [i for i in range(n_channels) if i % 2 == 1]
return frontal, central, parietal, prefrontal, posterior, left, right
def get_region_indices(name_to_idx, n_channels: int):
if not name_to_idx:
return _fallback_region_indices(n_channels)
central_names = ["CZ","C1","C2","C3","C4","C5","C6","CP1","CP2","CP3","CP4","CP5","CP6","FC1","FC2","FC3","FC4","FC5","FC6"]
frontal_names = ["FZ","F1","F2","F3","F4","F5","F6","F7","F8","AF3","AF4","AF7","AF8","FPZ","FP1","FP2","FCZ"]
parietal_names = ["PZ","P1","P2","P3","P4","P5","P6","POZ","PO3","PO4","PO5","PO6","PO7","PO8","CPZ"]
prefrontal_names = ["FP1","FP2","FPZ","AF3","AF4","AF7","AF8"]
posterior_names = ["O1","O2","OZ","PO7","PO8","PO3","PO4","PZ","P3","P4","P1","P2"]
central = pick_indices_by_names(name_to_idx, central_names)
frontal = pick_indices_by_names(name_to_idx, frontal_names)
parietal = pick_indices_by_names(name_to_idx, parietal_names)
prefrontal = pick_indices_by_names(name_to_idx, prefrontal_names)
posterior = pick_indices_by_names(name_to_idx, posterior_names)
left_names = ["FP1","AF3","AF7","F3","F5","F7"]
right_names = ["FP2","AF4","AF8","F4","F6","F8"]
left = pick_indices_by_names(name_to_idx, left_names)
right = pick_indices_by_names(name_to_idx, right_names)
if not (central and frontal and parietal and prefrontal and posterior):
fb = _fallback_region_indices(n_channels)
frontal2, central2, parietal2, prefrontal2, posterior2, left2, right2 = fb
frontal = frontal if frontal else frontal2
central = central if central else central2
parietal = parietal if parietal else parietal2
prefrontal = prefrontal if prefrontal else prefrontal2
posterior = posterior if posterior else posterior2
left = left if left else left2
right = right if right else right2
return frontal, central, parietal, prefrontal, posterior, left, right
# ==========================
# Welch PSD + band power
# ==========================
def welch_psd(eeg_tc: np.ndarray, fs: float):
nperseg = min(1024, eeg_tc.shape[0])
if nperseg < 128:
nperseg = min(256, eeg_tc.shape[0])
freqs, pxx = signal.welch(
eeg_tc, fs=fs, nperseg=nperseg, noverlap=nperseg // 2,
axis=0, scaling="density",
)
return freqs, pxx
def band_power_from_psd(freqs, pxx_fc, band):
lo, hi = band
m = (freqs >= lo) & (freqs < hi)
if not np.any(m):
return np.zeros((pxx_fc.shape[1],), dtype=np.float32)
# 兼容处理numpy 2.0+ 推荐使用 trapezoid旧版本用 trapz
if hasattr(np, "trapezoid"):
return np.trapezoid(pxx_fc[m, :], freqs[m], axis=0).astype(np.float32)
else:
return np.trapz(pxx_fc[m, :], freqs[m], axis=0).astype(np.float32)
def region_mean_power(freqs, pxx_fc, idx, band) -> float:
if not idx:
return 0.0
pw = band_power_from_psd(freqs, pxx_fc, band)
return float(np.mean(pw[idx]))
def compute_iaf(freqs, pxx_fc, posterior_idx):
lo, hi = BANDS_METRICS["Alpha"]
m = (freqs >= lo) & (freqs <= hi)
if not np.any(m) or not posterior_idx:
return 0.0
spec = np.mean(pxx_fc[:, posterior_idx], axis=1)
sub = spec[m]
fsub = freqs[m]
return float(fsub[int(np.argmax(sub))])
# ==========================
# 图EEG波形、PSD
# ==========================
def plot_eeg_waveforms(data_uv_tc: np.ndarray, fs: float, ch_names, out_dir: str, seconds: int = 10):
"""
固定用 FIXED_EEG_IDXS 画 EEG.png按重要性排序
data_uv_tc: (T, C) μV
"""
T, C = data_uv_tc.shape
# 1) 过滤越界索引(避免你的数据通道数不足时报错)
idxs = [i for i in FIXED_EEG_IDXS if 0 <= i < C]
if len(idxs) < len(FIXED_EEG_IDXS):
missing = [i for i in FIXED_EEG_IDXS if not (0 <= i < C)]
print(f"[WARN] Some fixed EEG indices out of range (C={C}): {missing}")
if len(idxs) == 0:
raise RuntimeError(f"No valid indices in FIXED_EEG_IDXS for current data (C={C}).")
# 2) 通道显示
picked_names = []
for idx in idxs:
# 找 idx 在 FIXED_EEG_IDXS 的位置,用对应标签
pos = FIXED_EEG_IDXS.index(idx)
std_label = FIXED_EEG_LABELS[pos] if pos < len(FIXED_EEG_LABELS) else f"CH{idx}"
if ch_names and idx < len(ch_names):
picked_names.append(f"{std_label}")
else:
picked_names.append(std_label)
# 3) 截取前 seconds 秒
max_samples = int(min(T, seconds * fs))
x = np.arange(max_samples) / fs
fig_h = 1.4 * len(idxs) + 1
fig, axes = plt.subplots(len(idxs), 1, figsize=(10, fig_h), sharex=True)
if len(idxs) == 1:
axes = [axes]
# 4) 分位数定范围,避免尖峰撑爆
seg = data_uv_tc[:max_samples, idxs].T # (n_ch, samples)
lo = float(np.percentile(seg, 1))
hi = float(np.percentile(seg, 99))
m = max(abs(lo), abs(hi))
m = max(m, 50.0)
for ax, ch_idx, nm in zip(axes, idxs, picked_names):
y = data_uv_tc[:max_samples, ch_idx]
ax.plot(x, y, linewidth=1.2)
ax.set_ylabel("μV")
ax.set_title(nm, loc="left", fontsize=10)
ax.grid(True, alpha=0.3)
ax.set_ylim(-m, m)
axes[-1].set_xlabel("Time (s)")
plt.tight_layout()
out_path = os.path.join(out_dir, "EEG.png")
plt.savefig(out_path, dpi=200)
plt.close(fig)
print(f"[OK] EEG waveform saved: {out_path}")
def plot_psd(eeg_uV_tc, fs, ch_names, out_dir):
C = eeg_uV_tc.shape[1]
chosen_idx = []
if ch_names:
mp = {n.upper(): i for i, n in enumerate(ch_names)}
for p in ["C3","C4","CZ"]:
if p in mp:
chosen_idx.append(mp[p])
if len(chosen_idx) < 3:
stds = [(i, float(np.std(eeg_uV_tc[:, i]))) for i in range(C)]
stds.sort(key=lambda x: x[1], reverse=True)
for i, _ in stds:
if i not in chosen_idx:
chosen_idx.append(i)
if len(chosen_idx) == 3:
break
chosen_name = [ch_names[i] for i in chosen_idx]
else:
stds = [(i, float(np.std(eeg_uV_tc[:, i]))) for i in range(C)]
stds.sort(key=lambda x: x[1], reverse=True)
chosen_idx = [i for i, _ in stds[:3]]
chosen_name = [f"CH{i}" for i in chosen_idx]
fig = plt.figure(figsize=(7.5, 4.8))
for idx, nm in zip(chosen_idx, chosen_name):
f, pxx = signal.welch(eeg_uV_tc[:, idx], fs=fs, nperseg=int(2*fs), noverlap=int(1*fs))
mask = (f >= PSD_FMIN) & (f <= PSD_FMAX)
p_db = 10 * np.log10(pxx[mask] + 1e-20)
plt.plot(f[mask], p_db, linewidth=1.8, label=nm)
plt.xlabel("Hz")
plt.ylabel("Power (dB)")
plt.title("PSD")
plt.grid(True, alpha=0.3)
plt.legend()
plt.tight_layout()
out_path = os.path.join(out_dir, "psd.png")
plt.savefig(out_path, dpi=200)
plt.close(fig)
print(f"[OK] psd.png -> {out_path}")
# ==========================
# Topomap如果有 xyz
# ==========================
def build_mne_raw_from_uV(eeg_uV_tc, fs, ch_names, xyz):
C = eeg_uV_tc.shape[1]
if not ch_names:
ch_names = [f"CH{i+1}" for i in range(C)]
data_v_ct = eeg_uV_tc.T * 1e-6 # (C,T) V
info = mne.create_info(ch_names=ch_names, sfreq=fs, ch_types=["eeg"] * C)
raw = mne.io.RawArray(data_v_ct, info, verbose=False)
if xyz is not None and isinstance(xyz, np.ndarray) and xyz.shape == (C, 3):
try:
ch_pos = {ch_names[i]: xyz[i, :] for i in range(C)}
montage = mne.channels.make_dig_montage(ch_pos=ch_pos, coord_frame="head")
raw.set_montage(montage, on_missing="ignore")
except Exception as e:
print(f"[WARN] set_montage failed (ignore): {e}")
else:
print("[WARN] electrode_xyz missing/invalid -> skip topomap")
return raw
def _raw_has_positions(raw):
try:
locs = np.array([ch["loc"][:3] for ch in raw.info["chs"]])
ok = np.isfinite(locs).all() and (np.linalg.norm(locs, axis=1) > 0).any()
return bool(ok)
except Exception:
return False
def compute_band_powers_for_topomap(raw, bands):
data = raw.get_data() # (C,T) V
fs = raw.info["sfreq"]
psds, freqs = mne.time_frequency.psd_array_welch(
data, sfreq=fs,
fmin=min(v[0] for v in bands.values()),
fmax=max(v[1] for v in bands.values()),
n_fft=int(2 * fs),
n_overlap=int(1 * fs),
average="mean",
verbose=False
)
out = {}
for k, (fmin, fmax) in bands.items():
idx = np.where((freqs >= fmin) & (freqs <= fmax))[0]
# 兼容处理numpy 2.0+ 推荐使用 trapezoid旧版本用 trapz
if hasattr(np, "trapezoid"):
bp = np.trapezoid(psds[:, idx], freqs[idx], axis=1) # (C,)
else:
bp = np.trapz(psds[:, idx], freqs[idx], axis=1) # (C,)
v = np.log10(bp + 1e-30)
v = v - np.mean(v)
out[k] = v
return out
def plot_average_topomap(raw, values, out_dir):
fig, ax = plt.subplots(1, 1, figsize=(6.5, 4.6))
im, _ = mne.viz.plot_topomap(values, raw.info, axes=ax, show=False, contours=0,sphere=(0, 0, 0, 0.11))
ax.set_title("0.8-30 Hz", fontsize=12)
plt.colorbar(im, ax=ax, shrink=0.85)
plt.tight_layout()
out_path = os.path.join(out_dir, "average_topomap.png")
plt.savefig(out_path, dpi=200)
plt.close(fig)
print(f"[OK] average_topomap.png -> {out_path}")
def plot_band_topomaps(raw, band_values, out_dir):
order = [
("delta", "δ (0.8-3.9Hz)"),
("theta", "θ (4-7.9Hz)"),
("alpha", "α (8-12.9Hz)"),
("beta", "β (13-30Hz)"),
("broad", "0.8-30 Hz"),
]
fig, axes = plt.subplots(1, 5, figsize=(16, 4.2))
ims = []
for ax, (k, title) in zip(axes, order):
im, _ = mne.viz.plot_topomap(band_values[k], raw.info, axes=ax, show=False, contours=0,extrapolate='head',sphere=(0, 0, 0, 0.11))
ax.set_title(title, fontsize=11)
ims.append(im)
fig.subplots_adjust(left=0.02, right=0.85, top=0.88, bottom=0.05, wspace=0.35)
cax = fig.add_axes([0.87, 0.15, 0.015, 0.7])
fig.colorbar(ims[-1], cax=cax)
out_path = os.path.join(out_dir, "topomaps.png")
plt.savefig(out_path, dpi=200)
plt.close(fig)
print(f"[OK] topomaps.png -> {out_path}")
# ==========================
# 生成 ResultData.txt
# ==========================
def compute_and_save_txt(model_path, mat_dir, out_dir, eeg_uV_tc, fs, ch_names):
pred_label = _predict_label_by_model(model_path, mat_dir)
recommend = "" if pred_label == "MDD" else ""
T, C = eeg_uV_tc.shape
mp = build_channel_index_map(ch_names, C)
frontal_idx, central_idx, parietal_idx, prefrontal_idx, posterior_idx, left_idx, right_idx = \
get_region_indices(mp, C)
freqs, pxx = welch_psd(eeg_uV_tc, fs)
central_alpha = region_mean_power(freqs, pxx, central_idx, BANDS_METRICS["Alpha"])
central_beta = region_mean_power(freqs, pxx, central_idx, BANDS_METRICS["Beta"])
frontal_alpha = region_mean_power(freqs, pxx, frontal_idx, BANDS_METRICS["Alpha"])
frontal_beta = region_mean_power(freqs, pxx, frontal_idx, BANDS_METRICS["Beta"])
par_alpha = region_mean_power(freqs, pxx, parietal_idx, BANDS_METRICS["Alpha"])
par_beta = region_mean_power(freqs, pxx, parietal_idx, BANDS_METRICS["Beta"])
central_ab = (central_alpha / (central_beta + EPS)) if central_beta > 0 else 0.0
frontal_ab = (frontal_alpha / (frontal_beta + EPS)) if frontal_beta > 0 else 0.0
par_ab = (par_alpha / (par_beta + EPS)) if par_beta > 0 else 0.0
central_theta = region_mean_power(freqs, pxx, central_idx, BANDS_METRICS["Theta"])
par_theta = region_mean_power(freqs, pxx, parietal_idx, BANDS_METRICS["Theta"])
central_tb = (central_theta / (central_beta + EPS)) if central_beta > 0 else 0.0
par_tb = (par_theta / (par_beta + EPS)) if par_beta > 0 else 0.0
if not left_idx or not right_idx:
left_idx = [i for i in prefrontal_idx if (i % 2 == 0)]
right_idx = [i for i in prefrontal_idx if (i % 2 == 1)]
left_alpha = region_mean_power(freqs, pxx, left_idx, BANDS_METRICS["Alpha"])
right_alpha = region_mean_power(freqs, pxx, right_idx, BANDS_METRICS["Alpha"])
prefrontal_alpha_asym = float(np.log(right_alpha + EPS) - np.log(left_alpha + EPS))
iaf = compute_iaf(freqs, pxx, posterior_idx)
pre_td = region_mean_power(freqs, pxx, prefrontal_idx, (BANDS_METRICS["Delta"][0], BANDS_METRICS["Theta"][1]))
pre_total = region_mean_power(freqs, pxx, prefrontal_idx, TOTAL_POWER_BAND)
pre_td_rel = (pre_td / (pre_total + EPS)) * 100.0 if pre_total > 0 else 0.0
def f1(x): return f"{x:.1f}"
txt = (
f"中央区α/β波比值:{f1(central_ab)}\n"
f"额区α/β波比值:{f1(frontal_ab)}\n"
f"顶区α/β波比值:{f1(par_ab)}\n"
f"中央区θ/β波比值:{f1(central_tb)}\n"
f"顶区θ/β波比值:{f1(par_tb)}\n"
f"前额叶α波不对称性:{f1(prefrontal_alpha_asym)}\n"
f"个体化α峰值频率:{f1(iaf)}\n"
f"前额叶θ+δ波功率:{f1(pre_td_rel)}\n"
f"是否推荐治疗:{recommend}\n"
)
out_path = os.path.join(out_dir, "ResultData.txt")
with open(out_path, "w", encoding="utf-8") as f:
f.write(txt)
print(f"[OK] ResultData.txt -> {out_path}")
# ==========================
# 一个函数一次性跑完txt + 图片)
# ==========================
def run_all(model_path: str, mat_dir: str, out_root: str, seconds: int = EEG_PLOT_SECONDS):
# 1) 选第一个 mat
if not os.path.exists(mat_dir):
raise RuntimeError(f"输入目录不存在: {mat_dir}")
mats = [f for f in os.listdir(mat_dir) if f.lower().endswith(".mat")]
if not mats:
raise RuntimeError(f"mat_dir 下找不到 .mat: {mat_dir}")
mats.sort()
mat_file = os.path.join(mat_dir, mats[0])
print(f"[INFO] Found mat: {mat_file}")
# 2) 创建输出目录
out_dir = ensure_outdir(out_root)
print(f"[INFO] Output dir: {out_dir}")
# --- 总是进行预处理 (默认模式) ---
print("[INFO] Mode: Raw Data (Default). Running preprocessing...")
temp_dir = os.path.join(out_dir, "temp_preprocessed")
mat_file = preprocess_mat_file(mat_file, temp_dir)
# 更新 mat_dir 指向临时目录(为了传给 compute_and_save_txt 里的 predict 接口)
mat_dir = temp_dir
# 3) 读 EEGμV
eeg_uV_tc, fs, ch_names, xyz = load_eeg_from_mat(mat_file)
print(f"[INFO] eeg shape(T,C)={eeg_uV_tc.shape}, fs={fs}")
# 5) 画图PSD + EEG
plot_psd(eeg_uV_tc, fs, ch_names, out_dir)
plot_eeg_waveforms(eeg_uV_tc, fs, ch_names, out_dir, seconds=seconds)
# 6) topomap有 xyz 才画)
try:
raw = build_mne_raw_from_uV(eeg_uV_tc, fs, ch_names, xyz)
if _raw_has_positions(raw):
band_vals = compute_band_powers_for_topomap(raw, BANDS_TOPOMAP)
plot_average_topomap(raw, band_vals["broad"], out_dir)
plot_band_topomaps(raw, band_vals, out_dir)
else:
print("[WARN] No valid positions -> skip topomap.")
except Exception as e:
print(f"[WARN] topomap failed -> skip. reason: {e}")
# 4) 指标写 txt
compute_and_save_txt(model_path, mat_dir, out_dir, eeg_uV_tc, fs, ch_names)
print("[DONE] txt + figures generated.")
return out_dir
if __name__ == "__main__":
import multiprocessing
multiprocessing.freeze_support()
import argparse
import sys
# 1. 路径锚定:获取资源绝对路径
def get_resource_path(relative_path):
"""
获取资源的绝对路径。
策略优先在当前执行目录EXE所在目录寻找。
这适用于“绿色软件”模式即资源文件model/raw_data直接放在EXE旁边。
"""
if getattr(sys, 'frozen', False):
# PyInstaller 打包后的 EXE 所在目录
base_path = os.path.dirname(sys.executable)
else:
# 开发环境:当前脚本所在目录
base_path = os.path.dirname(os.path.abspath(__file__))
return os.path.join(base_path, relative_path)
# 设置默认路径
DEFAULT_MODEL = get_resource_path(os.path.join("model", "Model_1.pth"))
# 这里我们保持 mat_dir 和 out_root 相对于 EXE 所在目录(或当前工作目录)
if getattr(sys, 'frozen', False):
EXE_DIR = os.path.dirname(sys.executable)
else:
EXE_DIR = os.path.dirname(os.path.abspath(__file__))
DEFAULT_MAT = os.path.join(EXE_DIR, "raw_data")
DEFAULT_OUT = os.path.join(EXE_DIR, "out")
# 2. 解析命令行参数
parser = argparse.ArgumentParser(description="EEG Depression Assessment Algorithm Integration")
parser.add_argument("--model_path", type=str, default=DEFAULT_MODEL, help="模型文件的路径 (.pth)")
parser.add_argument("--mat_dir", type=str, default=DEFAULT_MAT, help="输入文件夹路径 (包含原始EEG .mat)")
parser.add_argument("--out_root", type=str, default=DEFAULT_OUT, help="结果输出的根目录")
parser.add_argument("--seconds", type=int, default=10, help="画波形图截取的秒数")
args = parser.parse_args()
# 3. 检查关键路径
if not os.path.exists(args.mat_dir):
print(f"[WARN] 输入文件夹不存在: {args.mat_dir}")
if not os.path.exists(args.model_path):
print(f"[WARN] 模型文件不存在: {args.model_path}")
# 4. 执行主流程
print(f"[*] 运行配置:")
print(f" - Model : {args.model_path}")
print(f" - Input : {args.mat_dir}")
print(f" - Output: {args.out_root}")
print(f" - Mode : RAW (Auto Preprocess)")
run_all(args.model_path, args.mat_dir, args.out_root, seconds=args.seconds)

View File

@@ -0,0 +1,16 @@
import ctypes
import sys
def is_program_running(name='Global\\Parser_main'):
# 创建互斥体
mutex_name =name
h_mutex = ctypes.windll.kernel32.CreateMutexW(None, False, mutex_name)
# 检查互斥体是否已经存在
if ctypes.windll.kernel32.GetLastError() == 183: # ERROR_ALREADY_EXISTS
print("程序已经在运行.")
return True
return False

View File

@@ -0,0 +1,379 @@
# -*-coding:utf-8 -*-
'''
SunnyLinker的通讯驱动
'''
import ast
import socket
import threading
import time
import datetime
from typing import Dict
import numpy as np
from threading import Thread, Event
import serial
from scipy import signal
from serial.serialutil import SerialException
from protocol import ProtocolFrame
class RingBuffer:
def __init__(self, n_chan, n_points):
self.n_chan = n_chan
self.n_points = n_points
self.buffer = np.zeros((n_chan, n_points))
self.currentPtr = 0
self.readPtr = 0
self.nUpdate = 0
self.rawData = np.zeros((n_chan, 1))
## append buffer and update current pointer
def appendBuffer(self, data):
if self.nUpdate == self.n_points:
raise Exception("Buffer is full")
n = data.shape[1]
# 计算可以写入的元素数量
write_count = min(self.n_points - self.nUpdate, n)
# 写入新数据
self.buffer[:, np.mod(np.arange(self.currentPtr, self.currentPtr + write_count), self.n_points)] = data[:,:write_count]
# 更新结束指针
self.currentPtr = (self.currentPtr + write_count) % self.n_points
# 更新大小
self.nUpdate += write_count
## get data from buffer
def getData(self, count=50):
# 确保不会尝试读取超过缓冲区当前大小的数据
count = min(count, self.nUpdate)
# 计算读取结束后的下一个位置
next_read_ptr = (self.readPtr + count) % self.n_points
if self.readPtr + count <= self.n_points:
# 情况 1不环绕数据是连续的
end_index = next_read_ptr if next_read_ptr != 0 else self.n_points
data = self.buffer[:, self.readPtr:end_index]
else:
# 情况 2发生环绕数据被分成两部分
# 第一部分:从 readPtr 到缓冲区末尾
part1 = self.buffer[:, self.readPtr:]
# 第二部分:从缓冲区开头到 (count - part1.shape[1]) 个点
part2 = self.buffer[:, :next_read_ptr]
# 将两部分在列方向上拼接
data = np.concatenate((part1, part2), axis=1)
# 更新读指针
self.readPtr = next_read_ptr
# 更新大小
self.nUpdate -= count
return data
# reset buffer
def resetAllPara(self):
self.nUpdate = 0
self.currentPtr = 0
self.readPtr = 0 # add by lizhenhua 清空读指针
self.buffer = np.zeros((self.n_chan, self.n_points)) # add by lizhenhua 清空环形缓冲区
class SunnyLinker64(Thread, ):
t_buffer = 10
n_chan = 64
srate = 250
receiveData = b''
toUv=True#转为uV
RingBufferLock = threading.Lock()
# 单例模式
_instance = None
_initialized = False # 检查是否已经初始化
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super(SunnyLinker64, cls).__new__(cls)
return cls._instance
def __init__(self, host='127.0.0.1', port=7878, srate=250, n_chan=64,method = 'tcp'):
if SunnyLinker64._initialized:
return
Thread.__init__(self)
self.daemon = True
self.host = host
self.port = port
self.srate = srate
self.n_chan = n_chan
self.method = method #传输方式,'tcp'表示tcp传输,'serial'表示串口传输
self.__ringBuffer = RingBuffer(self.n_chan + 2,
int(np.round(self.t_buffer * self.srate)))
self.energy = 0 # 电量
self.status_code = 0 # 与采集设备通信的状态码0为异常1为正常
self.gain_value = 6 # 增益倍数
# 设置初始化标志为True防止重复初始化
SunnyLinker64._initialized = True
# --- 新增:用于心跳检测 ---
self.last_called = 0 # 初始化为0
self.last_called_lock = threading.Lock() # 保护 last_called 的访问
def set_sampleRate(self,sampleRate_Code=0x00):
'''
设置采样率
:param sampleRate_Code: 0x00:250Hz,0x01:500Hz,0x02:1000Hz,0x03:2000Hz
'''
function_code = 0x02
gain_code = 0x06
sampleRate_Code = [gain_code,sampleRate_Code]
packed_data = ProtocolFrame.pack(function_code, sampleRate_Code)
if self.method == 'tcp':
self.sock.send(packed_data)
def push_trigger(self,label):
'''
数据打标
@param label:标签类别
'''
function_code = None
label = [label]
packed_data = ProtocolFrame.pack(function_code, label)
if self.method == 'tcp' and hasattr(self,'serial'):
print('发送:', label, datetime.datetime.now().strftime('%H:%M:%S.%f')[:-3])
self.serial.write(packed_data)
def Impedance(self, On):
'''
阻抗检测开关
:param On:True为开启False为关闭
:return: 组好的协议帧
'''
function_code = 0x01
if On:
data = [0x1]
self.gain_value = 6
else:
data = [0x0]
self.gain_value = 6
packed_data = ProtocolFrame.pack(function_code, data)
if self.method == 'tcp':
self.sock.send(packed_data)
def connect(self):
try:
if self.method == 'serial':
# 开启com口波特率115200超时5
self.sock = serial.Serial(self.host, self.port, timeout=5)
self.sock.flushInput() # 清空缓冲区
count = self.sock.inWaiting() # 获取串口缓冲区数据
while not count:
count = self.sock.inWaiting() # 获取串口缓冲区数据
# # 接收和存储数据
data = (self.sock.read(count))
self.receiveData = self.receiveData + data # 将接收数据存储在字符串中
elif self.method == 'tcp':
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.sock.connect((self.host, int(self.port)))
self.set_sampleRate(0x00) #设置250Hz采样率
except Exception as e:
print("请打开头环")
print(e)
print("connected")
def extract_packet(self, packet):
# 存储一个点的八通道数据
dataList = []
# 存储116个点的八通道数据
dataMatrix = []
for j in range(5):
for i in range(self.n_chan):
if not self.toUv:#原始数据直接输出
val = (packet[194 * j + 25 + i * 3] << 16) | (packet[194 * j + 25 + 1 + i * 3] << 8) | packet[
194 * j + 25 + 2 + i * 3]
else:#转为uV
val = (packet[194 * j + 25 + i * 3] << 16) | (packet[194 * j + 25 + 1 + i * 3] << 8) | packet[
194 * j + 25 + 2 + i * 3]
if val < 8388608:
val = val * 4.5 / self.gain_value / 8388608 * 1000000;
else:
val = (val - 16777216) * 4.5 / self.gain_value / 8388608 * 1000000;
dataList.append(val)
#同步触发源
val = packet[194 * j + 25 + (i+1) * 3]
dataList.append(val)
#同步触发序号
val = packet[194 * j + 25 + (i+1) * 3+1]
dataList.append(val)
# 将数据矩阵进行拼接
if len(dataMatrix) == 0:
dataMatrix = np.asmatrix(dataList)
else:
dataMatrix = np.concatenate((dataMatrix, np.asmatrix(dataList)), axis=0)
dataList.clear()
return np.transpose(dataMatrix)
def run(self):
self.connect()
self.running = True
self.PackageLength = 998
# 启动心跳检测线程
threading.Thread(target=self.heartbeat_checker, daemon=True).start()
while self.running:
try:
if self.method == 'serial':
count = self.sock.inWaiting() # 获取串口缓冲区数据
if count:
# 接收和存储数据
data = (self.sock.read(count))
self.receiveData = self.receiveData + data # 将接收数据存储在字符串中
elif self.method == 'tcp':
data = self.sock.recv(600)
if not data:
break
self.receiveData += data
with self.last_called_lock:
self.last_called = time.time()
self.status_code = 1 # 收到数据,标记为正常
if len(self.receiveData) >= self.PackageLength and self.receiveData.rfind(
b'\x55\x55') >= self.PackageLength - 2:
index = self.receiveData.index(b'\xaa')
self.receiveData = self.receiveData[index:]
if len(self.receiveData) >= self.PackageLength:
onepackage = self.receiveData[:self.PackageLength]
if onepackage[7] != 0:
self.energy = onepackage[7] # 电量
self.receiveData = self.receiveData[self.PackageLength:]
dataMatrix = self.extract_packet(onepackage)
try:
with self.RingBufferLock:
self.__ringBuffer.appendBuffer(dataMatrix)
except Exception as e:
print("锁:写入异常",e)
# self.RingBufferLock.release()
except ConnectionResetError:
self.status_code = 0 # 状态异常
print("Connection was reset by the peer.")
break
self.sock.close()
# --- 新增:心跳检测线程 ---
def heartbeat_checker(self):
"""
定期检查是否在最近2秒内收到 eegData
如果超过2秒未收到则设置 status_code = 0
"""
while self.running:
time.sleep(0.5) # 每0.5秒检查一次
with self.last_called_lock:
now = time.time()
# 只有收到过一次数据后才开始判断超时
if self.last_called > 0 and (now - self.last_called) > 2:
if self.status_code != 0:
print("EEG data timeout: disconnected")
self.status_code = 0
def getImpedance(self, data,n_chan):
'''
获取阻抗值已经放大100倍单位是kΩ
@param data: 准备计算的通道数据每通道200个值注意不要把信号打标的通道传进来
@return: 返回各个通道的阻抗值
'''
impedanceList = []
data = data[:n_chan]
for channelindex in range(data.shape[0]):
if len(data[channelindex]) > 0:
data_list = []
# 设计陷波滤波器去除50Hz成分
is50filter = True
if is50filter:
b, a = signal.iirnotch(50, 30, self.srate) # 30是带宽1000是采样频率
data_list = signal.lfilter(b, a, data[channelindex].tolist())
else:
data_list.extend(data[channelindex].tolist())
data_list = data_list[-1000:]
# 执行FFT
fft_result = np.fft.fft(data_list)
fft_magnitude = np.abs(fft_result / len(data_list)) # 归一化FFT结果
freqs = np.fft.fftfreq(len(data_list), d=1 / self.srate) # 频率轴
# y_amp_modified = np.concatenate(([fft_magnitude[0] / len(t[0].tolist())],
# fft_magnitude[1:-1] * 2 / len(t[0].tolist()),
# [fft_magnitude[-1] / len(t[0].tolist())]))
# 找到幅值最大的频率成分的索引忽略直流分量即索引0
max_index = np.argmax(fft_magnitude[1:])
# 获取最大幅值的频率索引加上1因为索引0是直流分量
freq_index = max_index + 1
# 获取最大幅值
max_magnitude = fft_magnitude[freq_index]
# 阻抗
import math
result = math.sqrt(2) * math.pi * max_magnitude / 6 / 4
result *= 0.44 * 100 # 统一放大100倍
impedanceList.append(int(result))
# print(max_magnitude, result)
else:
impedanceList.append(0)
impedances = np.array(impedanceList)
return impedances
def getData(self,count):
'''
获取最新的数据
@param count: 每通道返回的最数值数目
@return: 所有通道的最新count个数值
'''
data=None
try:
with self.RingBufferLock:
data = self.__ringBuffer.getData(count)
except:
print("锁:读取异常")
# self.RingBufferLock.release()
return data
def GetDataLenCount(self):
'''
获取最新缓存中每个通道的数量
@return:
'''
return self.__ringBuffer.nUpdate
def ResetAll(self):
'''
清空缓存
@return:
'''
with self.RingBufferLock:
self.__ringBuffer.resetAllPara()
def stop(self):
self.running = False
if __name__ == "__main__":
# Usage
Linker = SunnyLinker64('127.0.0.1', 5086, 1000, 65)
Linker.start()
try:
while True:
time.sleep(0.005)
if(Linker.count()>0):
# print(Linker.ringBuffer.nUpdate)
t = Linker.getData()
print(t.shape[1], Linker.count())
# Linker.ringBuffer.nUpdate=0
# time.sleep(0.2)
except KeyboardInterrupt:
Linker.stop()

View File

@@ -0,0 +1,113 @@
# -*- mode: python ; coding: utf-8 -*-
import sys
import os
from PyInstaller.utils.hooks import collect_submodules, collect_data_files
# ========================================================
# 1. 工程配置区 (Project Config)
# ========================================================
block_cipher = None
ENTRY_POINT = 'start_parse.py'
APP_NAME = 'start_parse' # 打包后生成的文件夹名和 exe 名
# ========================================================
# 2. 依赖分析 (Dependency Analysis)
# ========================================================
hidden_imports = [
# eegParser 依赖
'numpy',
'numpy.lib.stride_tricks',
'pandas',
'scipy',
'scipy.io',
'scipy.io.savemat',
'scipy.signal',
# SunnyLinker 依赖
'serial',
'serial.serialutil',
'socket',
# zmq 通信依赖
'zmq',
'zmq.asyncio',
# 其他可能遗漏的模块
'threading',
'datetime',
]
# 收集 zmq 的所有子模块
try:
hidden_imports += collect_submodules('zmq')
except:
pass
# ========================================================
# 3. 资源锚定 (Data Anchoring)
# ========================================================
# 打包时需要包含的资源文件
datas = [
('xy_64.xlsx', '.'), # 电极位置文件
]
# 收集 mne 的数据文件(如果有)
try:
datas += collect_data_files('mne')
except:
pass
# ========================================================
# 4. 构建流程 (Build Process)
# ========================================================
a = Analysis(
[ENTRY_POINT],
pathex=[],
binaries=[],
datas=datas,
hiddenimports=hidden_imports,
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=['tkinter', 'PyQt5', 'PySide2', 'IPython', 'matplotlib'],
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
pyz,
a.scripts,
[],
exclude_binaries=True,
name=APP_NAME,
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=True,
upx_exclude=[],
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)
# ========================================================
# 5. 打包模式: OneDir (单文件夹)
# ========================================================
coll = COLLECT(
exe,
a.binaries,
a.zipfiles,
a.datas,
strip=False,
upx=True,
upx_exclude=[],
name=APP_NAME,
)

View File

@@ -0,0 +1,76 @@
# -*- coding: utf-8 -*-
"""
打包脚本 - datacollect
用于将 EEG 数据采集程序打包为独立的 exe 文件
"""
import os
import sys
import shutil
import subprocess
def main():
# 1. 定义路径
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DIST_DIR = os.path.join(BASE_DIR, 'dist')
BUILD_DIR = os.path.join(BASE_DIR, 'build')
APP_NAME = 'start_parse'
# 2. 清理旧构建
print("[1/3] Cleaning up old builds...")
for dir_path in [DIST_DIR, BUILD_DIR]:
if os.path.exists(dir_path):
try:
shutil.rmtree(dir_path)
print(f" Cleaned {os.path.basename(dir_path)}/")
except Exception as e:
print(f" Warning: Could not clean {dir_path}: {e}")
# 3. 检查必要文件
print("\n[2/3] Checking required files...")
required_files = ['start_parse.py', 'eegParser.py', 'build_algorithm.spec', 'xy_64.xlsx']
for f in required_files:
path = os.path.join(BASE_DIR, f)
if os.path.exists(path):
print(f"{f}")
else:
print(f"{f} NOT FOUND!")
sys.exit(1)
# 4. 运行 PyInstaller
print("\n[3/3] Running PyInstaller...")
spec_file = os.path.join(BASE_DIR, 'build_algorithm.spec')
cmd = [
sys.executable,
"-m", "PyInstaller",
spec_file,
"--clean",
"--noconfirm"
]
try:
subprocess.check_call(cmd, cwd=BASE_DIR)
except subprocess.CalledProcessError as e:
print(f"\n✗ PyInstaller failed with error code: {e.returncode}")
sys.exit(1)
# 5. 验证结果
exe_path = os.path.join(DIST_DIR, APP_NAME, f'{APP_NAME}.exe')
if os.path.exists(exe_path):
size_mb = os.path.getsize(exe_path) / (1024 * 1024)
print(f"\n{'='*50}")
print(f"✓ SUCCESS! Executable created:")
print(f" {exe_path}")
print(f" Size: {size_mb:.1f} MB")
print(f"{'='*50}")
print(f"\n部署说明:")
print(f" 1. 复制 dist/start_parse 文件夹到目标电脑")
print(f" 2. 确保目标电脑已安装 EEG 设备的 USB 驱动")
print(f" 3. 运行 start_parse.exe")
else:
print("\n✗ Build failed - executable not found")
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,72 @@
import os
import shutil
import subprocess
import sys
def main():
# 1. 定义路径
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DIST_DIR = os.path.join(BASE_DIR, 'dist')
APP_NAME = 'Depression_Decoder'
MODEL_SRC = os.path.join(BASE_DIR, 'model')
RAW_DATA_SRC = os.path.join(BASE_DIR, 'raw_data')
# 2. 清理旧构建
print("[1/3] Cleaning up old builds...")
if os.path.exists(DIST_DIR):
try:
shutil.rmtree(DIST_DIR)
print(" Cleaned dist/")
except Exception as e:
print(f" Warning: Could not clean dist/: {e}")
BUILD_DIR = os.path.join(BASE_DIR, 'build')
if os.path.exists(BUILD_DIR):
try:
shutil.rmtree(BUILD_DIR)
print(" Cleaned build/")
except Exception as e:
print(f" Warning: Could not clean build/: {e}")
# 3. 运行 PyInstaller
print("[2/3] Running PyInstaller...")
cmd = [
"pyinstaller",
"build_algorithm.spec",
"--clean",
"--noconfirm"
]
try:
subprocess.check_call(cmd, shell=True)
except subprocess.CalledProcessError:
print("Error: PyInstaller failed.")
sys.exit(1)
# 4. 复制外部资源 (如果存在)
print("[3/3] Copying external resources...")
# 确保 dist 目录存在 (pyinstaller 应该已经创建了)
if not os.path.exists(DIST_DIR):
os.makedirs(DIST_DIR)
for src_path, folder_name in [(MODEL_SRC, 'model'), (RAW_DATA_SRC, 'raw_data')]:
dst_path = os.path.join(DIST_DIR, folder_name)
if os.path.exists(src_path):
try:
if os.path.exists(dst_path):
shutil.rmtree(dst_path)
shutil.copytree(src_path, dst_path)
print(f" Copied {folder_name} to dist/")
except Exception as e:
print(f" Error copying {folder_name}: {e}")
else:
print(f" Note: {folder_name} source not found at {src_path}, skipping.")
print("\n" + "="*50)
print(f"SUCCESS! Executable is in: {DIST_DIR}")
print("="*50)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,427 @@
import os
import sys
import threading
import time
import numpy as np
import pandas as pd
from SunnyLinker import SunnyLinker64
from zmqServer import zmqServer
from zmqClient import zmqClient
from scipy.io import savemat
from scipy import signal
class Parser_main(threading.Thread):
def __init__(self):
threading.Thread.__init__(self)
self.Running = True
self.fs = 250 # 采样率
self.energy = 0 # 电量
self.status_code = 0 # 与采集设备通信的状态码0为异常1为正常
self.n_chan = 64
self.dataBuffer = []
self.file_num = 0#保存文件序号
self.subject_id = None #受试者ID
self.session_id = None #Session ID
self.last_print_time = None
# 预处理参数
self.enable_preprocess = True # 是否启用预处理
self.lowcut = 0.5 # 高通滤波截止频率 (Hz)
self.highcut = 50 # 低通滤波截止频率 (Hz)
self.notch_freq = 50 # 工频陷波频率 (Hz)
self.ref_chan_name = 'CPZ' # 参考电极名称
self.ref_chan_idx = None # 参考电极索引(运行时确定)
self._init_filter_cache() # 初始化滤波器缓存
# 单位转换参数
self.calibration_scale = 1.0 # 校准系数,用于修正单位转换误差
self.calibration_offset = 0 # 校准偏移量
self._conversion_verified = False # 是否已验证转换
def connect(self):
self.thread_data_server = SunnyLinker64('127.0.0.1', 7878, 250, 64,
method='tcp')
self.thread_data_server.toUv = True
self.thread_data_server.start()
self.zmqServer = zmqServer()
self.zmqServer.start()
self.zmqClient = zmqClient('127.0.0.1', 8088)
self.zmqClient.connect()
def run(self):
while self.Running:
# 同步信息
if self.zmqServer.state_mode == 'sync':
self.zmqClient.send_to_all('sync', self.zmqClient.state)
self.zmqServer.state_mode = 'rest'
# 状态异常,报告上位机
if self.status_code != self.thread_data_server.status_code:
self.status_code = self.thread_data_server.status_code
self.zmqClient.send_to_all('status_code', int(self.status_code))
# 返回电量
if self.energy != self.thread_data_server.energy:
self.energy = self.thread_data_server.energy
self.zmqClient.send_to_all('energy', int(self.energy))
# 更新文件序号
if self.subject_id != self.zmqServer.subject_id or self.session_id != self.zmqServer.session_id:
self.subject_id = self.zmqServer.subject_id
self.session_id = self.zmqServer.session_id
self.file_num = 0 #从零开始计数
if self.zmqServer.open_Impedance == True: # 开启阻抗检测功能,仅运行一次
self.thread_data_server.Impedance(True)
self.zmqServer.open_Impedance = -1
elif self.zmqServer.open_Impedance == False:
self.thread_data_server.Impedance(False)
self.zmqServer.open_Impedance = -1
if self.zmqServer.get_Impedance: # 返回阻抗值
if self.thread_data_server.GetDataLenCount() > self.fs:
Impe_data = self.thread_data_server.getData(self.fs)
# 计算阻抗
imps = self.thread_data_server.getImpedance(Impe_data, self.n_chan)
self.zmqClient.send_to_all('impedance', imps.tolist())
else:
pass
if self.thread_data_server.GetDataLenCount() < 50:
time.sleep(0.01)
continue
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
data = self.thread_data_server.getData(50)
data = data[:self.n_chan, :]
# 数据质量检查与预处理
if self.enable_preprocess:
# 1. 首先验证和校准单位转换
data, calibrated = self.verify_and_calibrate_unit(data)
if calibrated:
print('[INFO] 单位转换已自动校准')
# 2. 检查数据质量
issues = self.check_data_quality(data)
if issues:
print('[警告] 检测到数据质量问题:')
for issue in issues:
print(f' - {issue}')
print('[INFO] 正在进行信号预处理...')
# 3. 执行预处理
data = self.preprocess_data(data)
# 4. 预处理后验证
if issues:
new_issues = self.check_data_quality(data)
if not new_issues:
print(f'[INFO] 预处理完成,数据幅度正常: {np.max(np.abs(data)):.2f} µV')
else:
print('[警告] 预处理后仍存在问题:')
for issue in new_issues:
print(f' - {issue}')
if self.zmqServer.mat_generate:
# 检测是否需要重置缓冲区(第二次发送 matGenerate 时清空旧数据)
if self.zmqServer.reset_mat_buffer:
self.dataBuffer = []
self.last_print_time = None
self.zmqServer.reset_mat_buffer = False
print('[INFO] 数据缓冲区已重置,从头开始采集')
self.dataBuffer.append(data)
if len(self.dataBuffer) % 50 == 0:
current_time = time.time()
if self.last_print_time is not None:
elapsed_time = current_time - self.last_print_time
# 2500个点 = 50个数据块 * 50个采样点/数据块
actual_fs = 2500 / elapsed_time
print(f"接收 2500 个采样点耗时: {elapsed_time:.4f} 秒, 折合实际采样率: {actual_fs:.2f} Hz")
else:
print("开始计时...")
self.last_print_time = current_time
print('数据保存进度: {}/{}'.format(len(self.dataBuffer),int(self.zmqServer.save_win*self.fs//50)))
if len(self.dataBuffer) >= int(self.zmqServer.save_win*self.fs//50): #5分钟*60秒*250Hz / 50
self.zmqServer.mat_generate = False
matData = np.hstack(self.dataBuffer[:int(self.zmqServer.save_win*self.fs//50)])
self.dataBuffer = []
self.last_print_time = None # 重置计时器以备下次使用
self.pack2mat(matData,self.subject_id,self.session_id)
def pack2mat(self,data,subject_id,session_id):
#EEG数据
Data = data.T
#通道名称
channel_names = np.array(
['AIN1', 'AIN2', 'AIN3', 'AIN4', 'AIN5', 'AIN6', 'AIN7', 'AIN8', 'AIN9', 'AIN10', 'AIN11', 'AIN12',
'AIN13', 'AIN14', 'AIN15', 'AIN16', 'AIN17', 'AIN18', 'AIN19', 'AIN20', 'AIN21', 'AIN22', 'AIN23',
'AIN24', 'AIN25', 'AIN26', 'AIN27', 'AIN28', 'AIN29', 'AIN30', 'AIN31', 'AIN32', 'AIN33', 'AIN34',
'AIN35', 'AIN36', 'AIN37', 'AIN38', 'AIN39', 'AIN40', 'AIN41', 'AIN42', 'AIN43', 'AIN44', 'AIN45',
'AIN46', 'AIN47', 'AIN48', 'AIN49', 'AIN50', 'AIN51', 'AIN52', 'AIN53', 'AIN54', 'AIN55', 'AIN56',
'AIN57', 'AIN58', 'AIN59', 'AIN60', 'AIN61', 'AIN62', 'AIN63', 'AIN64'], dtype=object)
#采样率
sample_rate = self.fs
#通道数量
node_number = Data.shape[1]
# 时间轴
t = np.linspace(0, self.zmqServer.save_win, Data.shape[0])
t = t.reshape(len(t), 1)
#电极名称
electrode_name = np.array(['FP1', 'FP2', 'PO6', 'POZ', 'F3', 'F4', 'FPZ', 'AF4', 'FC3', 'PO8', 'CP2', 'CP1',
'FCZ', 'PO5', 'FC2', 'FC1', 'C3', 'C4', 'FC4', 'CP4', 'P3', 'P4', 'F5', 'C5', 'F6',
'PO4', 'CP6', 'CP5', 'PO3', 'CP3', 'FC6', 'FC5', 'CB1', 'CB2', 'P5', 'AF7', 'A1','T7',
'FT7', 'TP7', 'FT8', 'AF8', 'F8', 'F7', 'P6', 'C6', 'O2', 'O1', 'T8', 'P7', 'CZ','PZ',
'P8', 'FZ', 'OZ', 'PO7', 'TP8', 'AF3', 'C2', 'C1', 'P2', 'P1', 'F2', 'F1'],
dtype=object)
#电极三维坐标
electrode_xyz = self.read_ch_pos()
electrode_xyz.update({'A1': [-0.095, 0, -0.005]})
electrode_xyz = {key: electrode_xyz[key] for key in electrode_name}
electrode_xyz = np.array(list(electrode_xyz.values()))
#电极坐标所属的坐标系
electrode_coord_system = '10-20 spherical model'
#受试者ID
Subject_id = subject_id
#Session ID
Session_id = session_id
#参考电极方案
ref = 'CPZ'
#数据采集开始时间
start_time = 0
meta_struct = {
'subject_id': Subject_id,
'session_id': Session_id,
'ref': ref,
'start_time': start_time
}
eeg_struct = {
'data': Data,
'chn': channel_names,
'sample_rate': sample_rate,
'node_number': node_number,
't': t,
'electrode_name': electrode_name,
'electrode_xyz': electrode_xyz,
'electrode_coord_system': electrode_coord_system,
'meta': meta_struct,
}
fileDir = os.path.join('EEGfiles/',Subject_id,Session_id)
os.makedirs(fileDir,exist_ok=True)
filePath = os.path.join(fileDir,'eeg_data{}.mat'.format(self.file_num))
# 保存到 .mat 文件,顶层变量名为 'eeg'
savemat(filePath, {'eeg': eeg_struct})
print('EEGfile saved at {}'.format(filePath))
self.zmqClient.send_to_all('filePath', filePath)
self.file_num += 1
def read_ch_pos(self,file_path=r'xy_64.xlsx'):
"""
将电极位置信息转换为Dict
参数:
file_path: 电极位置存储文件, 必须包含'channel', 'x', 'y', 'z'
"""
if getattr(sys, 'frozen', False):
script_dir = sys._MEIPASS
else:
script_dir = os.path.dirname(os.path.abspath(__file__))
file_path = os.path.join(script_dir, file_path)
df = pd.read_excel(file_path)
# 确保列名正确
if not all(col in df.columns for col in ['channel', 'x', 'y', 'z']):
raise ValueError("DataFrame必须包含'channel', 'x', 'y', 'z'")
# 创建电极位置字典
ch_pos = {}
for _, row in df.iterrows():
ch_pos[row['channel']] = [row['x'], row['y'], row['z']]
return ch_pos
def _init_filter_cache(self):
"""初始化滤波器系数缓存"""
self._filter_cache = {
'highpass': None,
'lowpass': None,
'notch': None
}
self._cache_valid = False
def _design_filters(self):
"""设计滤波器系数"""
if self._cache_valid:
return
nyquist = self.fs / 2
fs_nyq = self.fs
# 高通滤波 (去除低频漂移)
high = self.lowcut / nyquist
if 0 < high < 1:
self._filter_cache['highpass'] = signal.butter(2, high, btype='high', output='ba')
# 低通滤波 (去除高频噪声)
low = self.highcut / nyquist
if 0 < low < 1:
self._filter_cache['lowpass'] = signal.butter(4, low, btype='low', output='ba')
# 50Hz 陷波滤波 (去除工频干扰)
Q = 30 # 品质因子
self._filter_cache['notch'] = signal.iirnotch(self.notch_freq, Q, fs=fs_nyq)
# 查找CPZ通道索引
electrode_name = ['FP1', 'FP2', 'PO6', 'POZ', 'F3', 'F4', 'FPZ', 'AF4', 'FC3', 'PO8', 'CP2', 'CP1',
'FCZ', 'PO5', 'FC2', 'FC1', 'C3', 'C4', 'FC4', 'CP4', 'P3', 'P4', 'F5', 'C5', 'F6',
'PO4', 'CP6', 'CP5', 'PO3', 'CP3', 'FC6', 'FC5', 'CB1', 'CB2', 'P5', 'AF7', 'A1','T7',
'FT7', 'TP7', 'FT8', 'AF8', 'F8', 'F7', 'P6', 'C6', 'O2', 'O1', 'T8', 'P7', 'CZ','PZ',
'P8', 'FZ', 'OZ', 'PO7', 'TP8', 'AF3', 'C2', 'C1', 'P2', 'P1', 'F2', 'F1']
try:
self.ref_chan_idx = electrode_name.index(self.ref_chan_name)
except ValueError:
self.ref_chan_idx = 50 # 默认CZ (对应索引50)
print(f'[警告] 未找到参考电极 {self.ref_chan_name},使用默认值 CZ')
self._cache_valid = True
print(f'[INFO] 预处理已启用 - 高通:{self.lowcut}Hz, 低通:{self.highcut}Hz, 陷波:{self.notch_freq}Hz, 参考:{self.ref_chan_name}(索引:{self.ref_chan_idx})')
def check_data_quality(self, data):
"""
检查数据质量
返回:
list: 发现的问题列表,空列表表示质量正常
"""
issues = []
# 检查幅度
amplitude = np.max(np.abs(data))
if amplitude > 1e6: # 超过 1mV = 1000µV
issues.append(f'幅度异常: {amplitude:.2e} (可能为原始ADC值或单位错误)')
elif amplitude > 1000:
issues.append(f'幅度偏高: {amplitude:.2f}')
# 检查平坦噪声 (通道可能未连接)
if np.std(data) < 0.01:
issues.append('信号过平,可能通道未连接')
# 检查饱和
n_saturated = np.sum(np.abs(data) > 1e8)
if n_saturated > 0:
issues.append(f'检测到 {n_saturated} 个采样点饱和')
return issues
def verify_and_calibrate_unit(self, data):
"""
验证并校准数据单位
SunnyLinker64 的转换公式:
val = raw_adc * 4.5 / gain_value / 8388608 * 1000000 (µV)
但如果硬件实际增益与 gain_value=6 不符,会导致单位错误。
本函数通过检测数据范围来验证和修正单位。
正常EEG信号范围: ±50-100 µV
如果检测到的数据范围是 ±1e6 量级,说明转换可能有问题
参数:
data: 原始数据
返回:
tuple: (校准后的数据, 是否进行了校准)
"""
if self._conversion_verified:
return data, False
amplitude = np.max(np.abs(data))
# 判断数据是否在合理范围内
# 正常EEG: 1 - 1000 µV (考虑某些高幅值情况)
# 异常: > 1e5 µV (可能是ADC原始值未转换或转换系数错误)
if amplitude > 1e6:
print('[警告] 检测到异常大幅值数据可能是ADC原始值或单位转换失败!')
print(f' 当前最大幅度: {amplitude:.2e} µV')
print('[INFO] 尝试自动校准单位转换...')
# SunnyLinker64 的理论转换系数约为 0.0894 µV/LSB
# 如果数据是原始ADC值需要除以这个系数来还原
theoretical_scale = 4.5 / 6 / 8388608 * 1e6 # 理论系数: ~0.0894 µV/LSB
# 计算校准系数
# 假设数据是原始ADC值需要除以 (amplitude / expected_amplitude)
# 正常EEG信号预期幅度约 100 µV
expected_amplitude = 100.0 # µV
if amplitude > expected_amplitude:
# 计算校准系数: 原始值 / 预期值 = 实际值 / 校准后值
self.calibration_scale = expected_amplitude / amplitude
# 应用校准
data = data * self.calibration_scale
print(f'[INFO] 校准完成,应用系数: {self.calibration_scale:.6e}')
print(f' 校准后最大幅度: {np.max(np.abs(data)):.2f} µV')
self._conversion_verified = True
return data, True
elif amplitude < 0.01:
print('[警告] 数据幅度接近零,可能通道未连接或设备异常')
self._conversion_verified = True
return data, False
def preprocess_data(self, data):
"""
EEG信号预处理
参数:
data: ndarray, shape (n_chan, n_samples), 原始EEG数据
返回:
ndarray: 预处理后的EEG数据
"""
if not self.enable_preprocess:
return data
# 确保数据是 float64 类型
data = data.astype(np.float64)
# 设计滤波器
self._design_filters()
# 1. 去除直流分量和低频漂移 (高通滤波)
if self._filter_cache['highpass'] is not None:
b, a = self._filter_cache['highpass']
for ch in range(data.shape[0]):
data[ch, :] = signal.filtfilt(b, a, data[ch, :])
# 2. 50Hz 工频陷波滤波
if self._filter_cache['notch'] is not None:
b, a = self._filter_cache['notch']
for ch in range(data.shape[0]):
data[ch, :] = signal.filtfilt(b, a, data[ch, :])
# 3. 低通滤波 (去除高频噪声)
if self._filter_cache['lowpass'] is not None:
b, a = self._filter_cache['lowpass']
for ch in range(data.shape[0]):
data[ch, :] = signal.filtfilt(b, a, data[ch, :])
# 4. 重参考 (以CPZ为参考)
if self.ref_chan_idx is not None and self.ref_chan_idx < data.shape[0]:
ref_signal = data[self.ref_chan_idx, :]
data = data - ref_signal
return data
def stop(self):
'''
停止运行
@return:
'''
self.zmqServer.stop()
self.Running=False

View File

@@ -0,0 +1,207 @@
import os
import sys
import threading
import time
import numpy as np
import pandas as pd
from SunnyLinker import SunnyLinker64
from zmqServer import zmqServer
from zmqClient import zmqClient
from scipy.io import savemat
class Parser_main(threading.Thread):
def __init__(self):
threading.Thread.__init__(self)
self.Running = True
self.fs = 250 # 采样率
self.energy = 0 # 电量
self.status_code = 0 # 与采集设备通信的状态码0为异常1为正常
self.n_chan = 64
self.dataBuffer = []
self.file_num = 0#保存文件序号
self.subject_id = None #受试者ID
self.session_id = None #Session ID
self.last_print_time = None
def connect(self):
self.thread_data_server = SunnyLinker64('127.0.0.1', 7878, 250, 64,
method='tcp')
self.thread_data_server.toUv = True
self.thread_data_server.start()
self.zmqServer = zmqServer()
self.zmqServer.start()
self.zmqClient = zmqClient('127.0.0.1', 8088)
self.zmqClient.connect()
def run(self):
while self.Running:
# 同步信息
if self.zmqServer.state_mode == 'sync':
self.zmqClient.send_to_all('sync', self.zmqClient.state)
self.zmqServer.state_mode = 'rest'
# 状态异常,报告上位机
if self.status_code != self.thread_data_server.status_code:
self.status_code = self.thread_data_server.status_code
self.zmqClient.send_to_all('status_code', int(self.status_code))
# 返回电量
if self.energy != self.thread_data_server.energy:
self.energy = self.thread_data_server.energy
self.zmqClient.send_to_all('energy', int(self.energy))
# 更新文件序号
if self.subject_id != self.zmqServer.subject_id or self.session_id != self.zmqServer.session_id:
self.subject_id = self.zmqServer.subject_id
self.session_id = self.zmqServer.session_id
self.file_num = 0 #从零开始计数
if self.zmqServer.open_Impedance == True: # 开启阻抗检测功能,仅运行一次
self.thread_data_server.Impedance(True)
self.zmqServer.open_Impedance = -1
elif self.zmqServer.open_Impedance == False:
self.thread_data_server.Impedance(False)
self.zmqServer.open_Impedance = -1
if self.zmqServer.get_Impedance: # 返回阻抗值
if self.thread_data_server.GetDataLenCount() > self.fs:
Impe_data = self.thread_data_server.getData(self.fs)
# 计算阻抗
imps = self.thread_data_server.getImpedance(Impe_data, self.n_chan)
self.zmqClient.send_to_all('impedance', imps.tolist())
else:
pass
if self.thread_data_server.GetDataLenCount() < 50:
time.sleep(0.01)
continue
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
data = self.thread_data_server.getData(50)
data = data[:self.n_chan, :]
if self.zmqServer.mat_generate:
# 检测是否需要重置缓冲区(第二次发送 matGenerate 时清空旧数据)
if self.zmqServer.reset_mat_buffer:
self.dataBuffer = []
self.last_print_time = None
self.zmqServer.reset_mat_buffer = False
print('[INFO] 数据缓冲区已重置,从头开始采集')
self.dataBuffer.append(data)
if len(self.dataBuffer) % 50 == 0:
current_time = time.time()
if self.last_print_time is not None:
elapsed_time = current_time - self.last_print_time
# 2500个点 = 50个数据块 * 50个采样点/数据块
actual_fs = 2500 / elapsed_time
print(f"接收 2500 个采样点耗时: {elapsed_time:.4f} 秒, 折合实际采样率: {actual_fs:.2f} Hz")
else:
print("开始计时...")
self.last_print_time = current_time
print('数据保存进度: {}/{}'.format(len(self.dataBuffer),int(self.zmqServer.save_win*self.fs//50)))
if len(self.dataBuffer) >= int(self.zmqServer.save_win*self.fs//50): #5分钟*60秒*250Hz / 50
self.zmqServer.mat_generate = False
matData = np.hstack(self.dataBuffer[:int(self.zmqServer.save_win*self.fs//50)])
self.dataBuffer = []
self.last_print_time = None # 重置计时器以备下次使用
self.pack2mat(matData,self.subject_id,self.session_id)
def pack2mat(self,data,subject_id,session_id):
#EEG数据
Data = data.T
#通道名称
channel_names = np.array(
['AIN1', 'AIN2', 'AIN3', 'AIN4', 'AIN5', 'AIN6', 'AIN7', 'AIN8', 'AIN9', 'AIN10', 'AIN11', 'AIN12',
'AIN13', 'AIN14', 'AIN15', 'AIN16', 'AIN17', 'AIN18', 'AIN19', 'AIN20', 'AIN21', 'AIN22', 'AIN23',
'AIN24', 'AIN25', 'AIN26', 'AIN27', 'AIN28', 'AIN29', 'AIN30', 'AIN31', 'AIN32', 'AIN33', 'AIN34',
'AIN35', 'AIN36', 'AIN37', 'AIN38', 'AIN39', 'AIN40', 'AIN41', 'AIN42', 'AIN43', 'AIN44', 'AIN45',
'AIN46', 'AIN47', 'AIN48', 'AIN49', 'AIN50', 'AIN51', 'AIN52', 'AIN53', 'AIN54', 'AIN55', 'AIN56',
'AIN57', 'AIN58', 'AIN59', 'AIN60', 'AIN61', 'AIN62', 'AIN63', 'AIN64'], dtype=object)
#采样率
sample_rate = self.fs
#通道数量
node_number = Data.shape[1]
# 时间轴
t = np.linspace(0, self.zmqServer.save_win, Data.shape[0])
t = t.reshape(len(t), 1)
#电极名称
electrode_name = np.array(['FP1', 'FP2', 'PO6', 'POZ', 'F3', 'F4', 'FPZ', 'AF4', 'FC3', 'PO8', 'CP2', 'CP1',
'FCZ', 'PO5', 'FC2', 'FC1', 'C3', 'C4', 'FC4', 'CP4', 'P3', 'P4', 'F5', 'C5', 'F6',
'PO4', 'CP6', 'CP5', 'PO3', 'CP3', 'FC6', 'FC5', 'CB1', 'CB2', 'P5', 'AF7', 'A1','T7',
'FT7', 'TP7', 'FT8', 'AF8', 'F8', 'F7', 'P6', 'C6', 'O2', 'O1', 'T8', 'P7', 'CZ','PZ',
'P8', 'FZ', 'OZ', 'PO7', 'TP8', 'AF3', 'C2', 'C1', 'P2', 'P1', 'F2', 'F1'],
dtype=object)
#电极三维坐标
electrode_xyz = self.read_ch_pos()
electrode_xyz.update({'A1': [-0.095, 0, -0.005]})
electrode_xyz = {key: electrode_xyz[key] for key in electrode_name}
electrode_xyz = np.array(list(electrode_xyz.values()))
#电极坐标所属的坐标系
electrode_coord_system = '10-20 spherical model'
#受试者ID
Subject_id = subject_id
#Session ID
Session_id = session_id
#参考电极方案
ref = 'CPZ'
#数据采集开始时间
start_time = 0
meta_struct = {
'subject_id': Subject_id,
'session_id': Session_id,
'ref': ref,
'start_time': start_time
}
eeg_struct = {
'data': Data,
'chn': channel_names,
'sample_rate': sample_rate,
'node_number': node_number,
't': t,
'electrode_name': electrode_name,
'electrode_xyz': electrode_xyz,
'electrode_coord_system': electrode_coord_system,
'meta': meta_struct,
}
fileDir = os.path.join('EEGfiles/',Subject_id,Session_id)
os.makedirs(fileDir,exist_ok=True)
filePath = os.path.join(fileDir,'eeg_data{}.mat'.format(self.file_num))
# 保存到 .mat 文件,顶层变量名为 'eeg'
savemat(filePath, {'eeg': eeg_struct})
print('EEGfile saved at {}'.format(filePath))
self.zmqClient.send_to_all('filePath', filePath)
self.file_num += 1
def read_ch_pos(self,file_path=r'xy_64.xlsx'):
"""
将电极位置信息转换为Dict
参数:
file_path: 电极位置存储文件, 必须包含'channel', 'x', 'y', 'z'
"""
if getattr(sys, 'frozen', False):
script_dir = sys._MEIPASS
else:
script_dir = os.path.dirname(os.path.abspath(__file__))
file_path = os.path.join(script_dir, file_path)
df = pd.read_excel(file_path)
# 确保列名正确
if not all(col in df.columns for col in ['channel', 'x', 'y', 'z']):
raise ValueError("DataFrame必须包含'channel', 'x', 'y', 'z'")
# 创建电极位置字典
ch_pos = {}
for _, row in df.iterrows():
ch_pos[row['channel']] = [row['x'], row['y'], row['z']]
return ch_pos
def stop(self):
'''
停止运行
@return:
'''
self.zmqServer.stop()
self.Running=False

View File

@@ -0,0 +1,185 @@
# -*- coding: utf-8 -*-
"""
EEG Data Quality Check - eeg_data0.mat
===================================
1. Time Domain Signal (Full Duration)
2. Amplitude Spectrum (FFT)
3. Power Spectral Density (Linear Scale)
4. Power Spectral Density (dB Scale)
"""
import numpy as np
import matplotlib.pyplot as plt
import mne
from scipy import signal
from scipy.io import loadmat
def load_and_preprocess(filepath):
"""Load .mat file (custom format) and basic preprocessing."""
mat_data = loadmat(filepath, simplify_cells=True)
eeg = mat_data['eeg']
# Extract data (shape: samples x channels)
data = eeg['data'].T # Transpose to (channels x samples)
sfreq = eeg['sample_rate']
# Get channel names (try multiple possible keys)
if 'chn' in eeg:
ch_names = list(eeg['chn'])
elif 'electrode_name' in eeg:
ch_names = list(eeg['electrode_name'])
else:
n_channels = data.shape[0]
ch_names = [f'Ch{i+1}' for i in range(n_channels)]
# Create MNE Info object
info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types='eeg')
raw = mne.io.RawArray(data, info)
raw.filter(l_freq=0.5, h_freq=10, fir_design='firwin', verbose=False)
return raw
def main():
filepath = r"D:\Ivey\Code_New_Proj\brainplot\plot64\eeg_data0511.mat"
output_path = r"D:\Ivey\Code_New_Proj\brainplot\plot64\eeg_quality_check_depression.png"
raw = load_and_preprocess(filepath)
# Print all channel names first
print(f"\nAvailable channels ({len(raw.ch_names)}):")
for i, ch in enumerate(raw.ch_names):
print(f" {i:3d}: {ch}")
select_channel = ['AIN5']
raw.pick(select_channel)
# Use all channels, full duration
ch_names = raw.ch_names
n_channels = len(ch_names)
data = raw.get_data()
sfreq = raw.info['sfreq']
n_samples = data.shape[1]
duration = n_samples / sfreq
print(f"Info: {n_channels} channels, {duration:.1f}s, {sfreq:.0f} Hz")
# Compute frequency domain data
n_fft = 2**int(np.ceil(np.log2(n_samples)))
freqs_fft = np.fft.rfftfreq(n_fft, 1 / sfreq)
fft_vals = np.fft.rfft(data, n=n_fft)
amplitude = np.abs(fft_vals) / n_fft * 2
freqs_psd, psd = signal.welch(data, fs=sfreq, nperseg=4096,
noverlap=2048, scaling='density')
# Frequency mask: 0.5-80 Hz
mask_fft = (freqs_fft >= 0.5) & (freqs_fft <= 80)
mask_psd = (freqs_psd >= 0.5) & (freqs_psd <= 80)
freq_fft = freqs_fft[mask_fft]
freq_psd = freqs_psd[mask_psd]
# Plot: 4 rows x 1 column
fig, axes = plt.subplots(4, 1, figsize=(16, 20))
fig.suptitle(f'EEG Data Quality Check — {", ".join(ch_names)}, '
f'Full Duration: {duration:.1f}s',
fontsize=16, fontweight='bold', y=0.995)
# Colormap for distinct channel
cmap = plt.cm.tab10 if n_channels <= 10 else plt.cm.tab20
colors = [cmap(i) for i in np.linspace(0, 1, n_channels)]
# ---- Row 1: Time Domain Signal ----
ax = axes[0]
offset = 0
step = max(100, np.std(data, axis=1).mean() * 1e6 * 4)
# Downsample for display
ds = max(1, n_samples // (int(duration) * 500))
t = np.arange(0, n_samples, ds) / sfreq
for i in range(n_channels):
sig = data[i, ::ds] * 1e6 + offset
ax.plot(t, sig, linewidth=0.5, alpha=0.9, color=colors[i], label=ch_names[i])
ax.text(t[0] - 0.5, offset, ch_names[i], fontsize=7, va='center', ha='right', color=colors[i])
offset += step
ax.set_xlim(0, duration)
ax.set_xlabel('Time (s)')
ax.set_ylabel('Amplitude (μV)')
ax.set_title('1. Time Domain Signal (Full Duration)', fontweight='bold')
ax.grid(True, alpha=0.3)
ax.legend(loc='upper right', fontsize=7, ncol=max(1, n_channels // 3), framealpha=0.8)
# ---- Row 2: Amplitude Spectrum (FFT) ----
ax = axes[1]
amp_data = amplitude[:, mask_fft] * 1e6 # (n_channels, n_freqs)
for i in range(n_channels):
ax.plot(freq_fft, amp_data[i], color=colors[i], linewidth=1.0, alpha=0.85, label=ch_names[i])
ax.axvline(50, color='red', linestyle='--', alpha=0.6, label='50 Hz Mains')
ax.set_xlim(0.5, 30)
ax.set_xlabel('Frequency (Hz)')
ax.set_ylabel('Amplitude (μV)')
ax.set_title('2. Amplitude Spectrum (FFT)', fontweight='bold')
ax.grid(True, alpha=0.3)
ax.legend(loc='upper right', fontsize=7, ncol=max(1, n_channels // 3), framealpha=0.8)
# ---- Row 3: PSD (Linear Scale) ----
ax = axes[2]
psd_data = psd[:, mask_psd] * 1e12 # (n_channels, n_freqs)
for i in range(n_channels):
ax.plot(freq_psd, psd_data[i], color=colors[i], linewidth=1.0, alpha=0.85, label=ch_names[i])
ax.axvline(50, color='red', linestyle='--', alpha=0.6, label='50 Hz Mains')
ax.set_xlim(0.5, 80)
ax.set_xlabel('Frequency (Hz)')
ax.set_ylabel('Power (μV²/Hz)')
ax.set_title('3. Power Spectral Density (Linear Scale)', fontweight='bold')
ax.grid(True, alpha=0.3)
ax.legend(loc='upper right', fontsize=7, ncol=max(1, n_channels // 3), framealpha=0.8)
# ---- Row 4: PSD (dB Scale) ----
ax = axes[3]
for i in range(n_channels):
psd_dbi = 10 * np.log10(psd_data[i] + 1e-20)
ax.plot(freq_psd, psd_dbi, color=colors[i], linewidth=1.0, alpha=0.85, label=ch_names[i])
ax.axvline(50, color='red', linestyle='--', alpha=0.6, label='50 Hz Mains')
ax.set_xlim(0.5, 80)
ax.set_xlabel('Frequency (Hz)')
ax.set_ylabel('Power (dB)')
ax.set_title('4. Power Spectral Density (dB Scale)', fontweight='bold')
ax.grid(True, alpha=0.3)
ax.legend(loc='upper right', fontsize=7, ncol=max(1, n_channels // 3), framealpha=0.8)
plt.tight_layout()
plt.subplots_adjust(top=0.97)
plt.savefig(output_path, dpi=150, bbox_inches='tight',
facecolor='white', edgecolor='none')
print(f"Figure saved to: {output_path}")
plt.show()
if __name__ == "__main__":
main()
from scipy.io import loadmat
import numpy as np
mat = loadmat(r'D:\Ivey\Code_New_Proj\brainplot\plot64\eeg_data0511.mat', simplify_cells=True)
data = mat['eeg']['data'] # (samples, channels)
sfreq = 250
seg1 = data[0:int(10*sfreq), :] # 0-10s
seg2 = data[int(10*sfreq):int(20*sfreq), :] # 10-20s
print('Segment 1 (0-10s) shape:', seg1.shape)
print('Segment 2 (10-20s) shape:', seg2.shape)
print('Are they equal?', np.allclose(seg1, seg2))
print('Max difference:', np.max(np.abs(seg1 - seg2)))
print('Mean difference:', np.mean(np.abs(seg1 - seg2)))
# Check correlation
corr = np.corrcoef(seg1.flatten(), seg2.flatten())[0, 1]
print(f'Correlation: {corr:.4f}')

Binary file not shown.

After

Width:  |  Height:  |  Size: 521 KiB

View File

@@ -0,0 +1,193 @@
from typing import List, Tuple, Union, Optional
class ProtocolFrame:
# 协议常量
FRAME_HEADER = 0xAA
FRAME_TAIL1 = 0x55
FRAME_TAIL2 = 0x55
RESERVED_SIZE = 6
MIN_FRAME_SIZE = 13 # 帧头1 + 功能1 + 长度2 + 预留6 + CRC1 + 包尾2
MAX_DATA_LENGTH = 0xFFFF # 最大数据长度 (2字节能表示的最大值)
@staticmethod
def calculate_crc8(data: bytes) -> bytes:
"""
计算CRC8校验值
Args:
data: 需要计算CRC的数据
Returns:
一个字节的CRC值bytes类型
"""
crc = 0
for byte in data:
crc ^= byte
for _ in range(8):
crc = ((crc << 1) ^ 0x07 if crc & 0x80 else crc << 1) & 0xFF
return bytes([crc])
@classmethod
def pack(cls, function, data: Union[bytes, bytearray, List[int]],
reserved: Optional[Union[bytes, bytearray, List[int]]] = None) -> bytes:
"""
协议打包函数
Args:
function: 功能码 (1字节)
data: 数据块
reserved: 预留字节(6字节可选)
Returns:
打包后的字节数据
"""
# 检查功能码
if function != None:
if not 0 <= function <= 0xFF:
raise ValueError("功能码必须是1字节")
# 转换数据为bytearray
if isinstance(data, list):
data = bytearray(data)
elif isinstance(data, bytes):
data = bytearray(data)
# 检查数据长度
data_length = len(data)
if data_length > cls.MAX_DATA_LENGTH:
raise ValueError(f"数据长度超过最大值 {cls.MAX_DATA_LENGTH}")
# 处理预留字节
if reserved is None:
reserved = bytearray([0] * cls.RESERVED_SIZE)
else:
if isinstance(reserved, list):
reserved = bytearray(reserved)
elif isinstance(reserved, bytes):
reserved = bytearray(reserved)
if len(reserved) != cls.RESERVED_SIZE:
raise ValueError(f"预留字节必须是{cls.RESERVED_SIZE}字节")
# 构建帧
frame = bytearray([cls.FRAME_HEADER]) # 帧头 (1字节)
if function != None:
frame.append(function) # 功能码 (1字节)
data_length+=6
# 数据长度 (2字节大端序)
frame.append((data_length >> 8) & 0xFF) # 高字节
frame.append(data_length & 0xFF) # 低字节
if function != None:
frame.extend(reserved) # 预留字节 (6字节)
frame.extend(data) # 数据块 (变长)
# 计算CRC (从功能码开始到数据块结束)
crc = cls.calculate_crc8(frame[1:]) # 不包含帧头
frame.extend(crc) # CRC校验 (1字节)
# 添加帧尾
frame.extend([cls.FRAME_TAIL1, cls.FRAME_TAIL2]) # 帧尾 (2字节)
return bytes(frame)
@classmethod
def unpack(cls, data: Union[bytes, bytearray]) -> Tuple[int, bytearray, bytearray]:
"""
协议解包函数
Args:
data: 待解析的字节数据
Returns:
(功能码, 数据块, 预留字节)
Raises:
ValueError: 当数据格式不正确时
"""
# 检查数据长度
if len(data) < cls.MIN_FRAME_SIZE:
raise ValueError("数据长度不足")
# 检查帧头
if data[0] != cls.FRAME_HEADER:
raise ValueError("帧头错误")
# 检查帧尾
if data[-2:] != bytes([cls.FRAME_TAIL1, cls.FRAME_TAIL2]):
raise ValueError("帧尾错误")
# 解析基本信息
function = data[1] # 功能码 (1字节)
# 数据长度 (2字节大端序)
data_length = (data[2] << 8) | data[3]
reserved = data[4:10] # 预留字节 (6字节)
# 检查数据长度
expected_length = cls.MIN_FRAME_SIZE + data_length
if len(data) != expected_length:
raise ValueError(f"数据长度不匹配: 期望{expected_length}字节,实际{len(data)}字节")
# 提取数据块
payload = data[10:10 + data_length]
# 验证CRC (从功能码开始到数据块结束)
received_crc = data[-3]
calculated_crc = cls.calculate_crc8(data[1:-3])[0] # 获取字节值
if received_crc != calculated_crc:
raise ValueError(f"CRC校验失败: 期望{calculated_crc:02X},实际{received_crc:02X}")
return function, bytearray(payload), bytearray(reserved)
def print_hex(data: bytes, label: str = ""):
"""打印十六进制数据,并按字节添加空格"""
hex_str = ' '.join([f"{b:02X}" for b in data])
if label:
print(f"{label}: {hex_str}")
else:
print(hex_str)
def print_frame_details(data: bytes):
"""打印帧的详细信息"""
print("帧详细信息:")
print(f"帧头: {data[0]:02X}")
print(f"功能码: {data[1]:02X}")
print(f"数据长度: {data[2]:02X} {data[3]:02X} ({(data[2] << 8) | data[3]}字节)")
print(f"预留字节: {' '.join([f'{b:02X}' for b in data[4:10]])}")
data_length = (data[2] << 8) | data[3]
print(f"数据块: {' '.join([f'{b:02X}' for b in data[10:10 + data_length]])}")
print(f"CRC校验: {data[-3]:02X}")
print(f"帧尾: {data[-2]:02X} {data[-1]:02X}")
# 使用示例
def example_usage():
try:
# 示例1简单数据打包
function_code = 0x01
data = [0x1]
packed_data = ProtocolFrame.pack(function_code, data)
print_hex(packed_data, "示例1 - 完整帧")
print_frame_details(packed_data)
print()
# 示例3解包验证
function, payload, reserved = ProtocolFrame.unpack(packed_data)
print("解包结果:")
print(f"功能码: 0x{function:02X}")
print_hex(payload, "数据块")
print_hex(reserved, "预留字节")
except ValueError as e:
print(f"错误: {e}")
if __name__ == "__main__":
example_usage()

View File

@@ -0,0 +1,17 @@
import time
from eegParser import Parser_main
from RunOnce import is_program_running
if __name__ == "__main__":
if not is_program_running():
parser_ = Parser_main()
parser_.connect()
try:
parser_.start()
while not parser_.zmqServer.IsExitApp:
time.sleep(1)
except KeyboardInterrupt:
parser_.stop()

View File

@@ -0,0 +1,137 @@
# -*- coding: utf-8 -*-
"""
PyInstaller 打包验证脚本
用于在没有 EEG 设备的情况下验证打包是否成功
"""
import os
import sys
import subprocess
import shutil
def check_pyinstaller_installed():
"""检查 PyInstaller 是否安装"""
try:
result = subprocess.run(['pyinstaller', '--version'],
capture_output=True, text=True)
print(f"✓ PyInstaller 版本: {result.stdout.strip()}")
return True
except FileNotFoundError:
print("✗ PyInstaller 未安装")
return False
def check_dist_folder():
"""检查 dist 文件夹是否存在"""
base_dir = os.path.dirname(os.path.abspath(__file__))
dist_dir = os.path.join(base_dir, 'dist', 'start_parse')
if os.path.exists(dist_dir):
print(f"✓ dist 文件夹存在: {dist_dir}")
# 检查 exe 文件
exe_path = os.path.join(dist_dir, 'start_parse.exe')
if os.path.exists(exe_path):
size_mb = os.path.getsize(exe_path) / (1024 * 1024)
print(f"✓ 可执行文件存在: start_parse.exe ({size_mb:.1f} MB)")
else:
print("✗ 可执行文件不存在")
return False
# 检查资源文件
xlsx_path = os.path.join(dist_dir, 'xy_64.xlsx')
if os.path.exists(xlsx_path):
print(f"✓ 资源文件存在: xy_64.xlsx")
else:
print("✗ 资源文件 xy_64.xlsx 不存在")
return True
else:
print(f"✗ dist 文件夹不存在,请先运行打包")
return False
def check_dependencies():
"""检查关键依赖是否在打包中"""
base_dir = os.path.dirname(os.path.abspath(__file__))
dist_dir = os.path.join(base_dir, 'dist', 'start_parse')
if not os.path.exists(dist_dir):
return False
# 检查关键 DLL 文件
critical_dlls = [
# zmq 依赖
'libzmq.pyd',
# numpy 依赖
'numpy.core._multiarray_umath.cp310-win_amd64.pyd',
# scipy 依赖
'scipy.special._ufuncs.cp310-win_amd64.pyd',
]
print("\n检查关键依赖文件:")
found_count = 0
for dll in critical_dlls:
found = False
for root, dirs, files in os.walk(dist_dir):
if dll in files:
found = True
break
status = "" if found else ""
print(f" {status} {dll}")
if found:
found_count += 1
return found_count >= len(critical_dlls) // 2
def test_imports():
"""测试关键模块是否可以导入"""
print("\n测试模块导入:")
modules = ['zmq', 'serial', 'numpy', 'pandas', 'scipy']
success = True
for mod in modules:
try:
__import__(mod)
print(f"{mod}")
except ImportError as e:
print(f"{mod}: {e}")
success = False
return success
def main():
print("=" * 60)
print("PyInstaller 打包验证")
print("=" * 60)
checks = [
("1. 检查 PyInstaller 安装", check_pyinstaller_installed),
("2. 检查 dist 文件夹", check_dist_folder),
("3. 检查依赖文件", check_dependencies),
("4. 测试模块导入", test_imports),
]
results = []
for name, check_func in checks:
print(f"\n{name}")
print("-" * 40)
results.append(check_func())
print("\n" + "=" * 60)
print("验证结果汇总:")
print("=" * 60)
all_passed = all(results)
if all_passed:
print("✓ 所有检查通过!打包成功。")
print("\n下一步:")
print(" 1. 将 dist/start_parse 文件夹复制到目标电脑")
print(" 2. 连接 EEG 设备并运行 start_parse.exe")
print(" 3. 观察控制台输出是否正常")
else:
print("✗ 部分检查未通过,请查看上述详细信息")
return all_passed
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,57 @@
import threading
import time
import json
import zmq
class zmqClient:
def __init__(self, host, port):
self.host = host
self.port = port
self.client_socket = None
self.running = False
# 记录客户端连接前的状态
self.state = {
'status_code': None,
'energy': None
}
def connect(self):
# 创建 ZeroMQ 上下文
self.context = zmq.Context()
# 创建 REQ 套接字(请求端)
self.client_socket = self.context.socket(zmq.DEALER)
# client_id = b'client1'
# self.client_socket.setsockopt(zmq.IDENTITY,client_id)
self.client_socket.connect(f"tcp://{self.host}:{self.port}") # 连接到服务器
self.running = True
def send_to_all(self, method,params):
if method in self.state.keys():
self.state[method] = params
try:
if self.running and self.client_socket != None:
msg = {'method': method, 'params': params}
# 发送响应
# print(msg)
self.client_socket.send_multipart([b'', json.dumps(msg).encode('utf-8')])
else:
if method in self.state.keys():
self.state[method] = params
except ConnectionResetError:
print("Connection lost.")
self.running = False
except Exception as e:
print(f"An error occurred: {e}")
def close_connection(self):
self.running = False
self.client_socket.close()
self.context.term()
print("Client closed explicitly.")
# 使用TCP客户端
if __name__ == "__main__":
client = zmqClient('127.0.0.1', 8099)
client.connect()
# client.close_connection()

View File

@@ -0,0 +1,119 @@
import numpy as np
import zmq
import threading
import json
from SunnyLinker import SunnyLinker64
class zmqServer(threading.Thread):
def __init__(self, host='0.0.0.0', port=8099):
threading.Thread.__init__(self)
self.host = host
self.port = port
self.running = False
self.get_Impedance = False # 是否返回阻抗值
self.open_Impedance = None # 是否开启阻抗检测功能
self.StartDecode = False # false 停止解码true=开始解码
self.StartTrain = False # False未进入训练状态True处于训练状态
self.state_mode = None # 'train'为训练状态rest'为休息状态,'test'为测试状态
self.currentLabel = -1 # 接收刺激端消息,了解刺激端当前的训练标签
self.IsExitApp = False # 当socket收到2的时候就置为True代表遥退出系统了。
self.getReport = False # 获取训练报告内容
self.mat_generate = False # 保存mat文件True开始False暂停
self.reset_mat_buffer = False # 重置缓冲区标志True表示下次开始采集需要清空旧数据
self.subject_id = None #受试者ID
self.session_id = None #Session ID
self.save_win = 0 #保存数据时长
self.daemon = True
# 创建 ZeroMQ 上下文
self.context = zmq.Context()
# 创建 REP 套接字(响应端)
self.socket = self.context.socket(zmq.ROUTER)
self.socket.bind(f"tcp://{self.host}:{self.port}") # 绑定到端口 8099
self.targetFreqs = []
self.changeTarget = False # 更换目标频率
self.sunnyLinker = SunnyLinker64(None, None, None, None,None) #单例模式类已在Decoder实例化
self.labels = [0x01, 0x02,0x03]
self.decoder_switch = False #更换解码器
self.decoder_class = None #解码器类别 'ssvep','ssmvep','mi'
def run(self):
self.running = True
print(f"Server is running on {self.host}:{self.port}")
try:
while self.running:
# 等待客户端请求
_,_,message = self.socket.recv_multipart()
message = json.loads(message.decode('utf-8'))
print(f"Received request: {message}")
# 处理请求
method = message.get("method")
params = message.get("params")
if method == "sync":
self.state_mode = 'sync'
if method == "targetFreqs":
if not isinstance(params,list):
print('targetFreqs must be a list')
continue
if params != self.targetFreqs:
self.targetFreqs = params
self.changeTarget = True
if method == "decoderClass":
if not isinstance(params,str):
print('decoderClass must be a str')
continue
# if params != self.decoder_class:
self.decoder_class = params
self.decoder_switch = True
if method == "getReport":
self.getReport = True
if method == "train":#训练状态
self.state_mode = 'train'
self.StartTrain = True
self.currentLabel = params # 当前刺激端的训练标签
self.sunnyLinker.push_trigger(self.labels[self.currentLabel])
elif method == "predict":#预测状态
self.state_mode = 'predict'
if params == 1: #开始解码
self.StartDecode = True
self.sunnyLinker.push_trigger(0x63)
elif params == 2: #停止解码
self.IsExitApp = True
self.running = False
elif method == "rest": #休息状态
self.state_mode = 'rest'
elif method == "impedance":
if params == 1:
self.open_Impedance = True # 开启阻抗
self.get_Impedance = True # 返回阻抗
elif params == 2:
self.open_Impedance = False # 关闭阻抗
self.get_Impedance = False # 停止返回阻抗
elif method == "matGenerate":
self.subject_id = str(params['subject_id'])
self.session_id = str(params['session_id'])
self.save_win = int(params['time'])
self.mat_generate = True
self.reset_mat_buffer = True # 每次发送 matGenerate 都重置缓冲区
elif method == "stop": # 停止
self.mat_generate = False
except Exception as e:
print(f"An socket error occurred: {e}")
finally:
self.running = False
# 关闭套接字和上下文
self.socket.close()
self.context.term()
print("Server socket and context closed.")
def stop(self):
"""显式关闭服务器"""
self.running = False
self.socket.close()
self.context.term()
print("Server closed explicitly.")
if __name__ == '__main__':
server = zmqServer()
server.start()