original push

This commit is contained in:
Ivey Song
2026-06-01 13:18:36 +08:00
commit 8426770db6
46 changed files with 341750 additions and 0 deletions

56
algorithm_V1/.gitignore vendored Normal file
View File

@@ -0,0 +1,56 @@
# Byte-compiled / optimized / DLL files
__pycache__/
# Distribution / packaging
build/
dist/
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# data format
*.dat
*.csv
*.edf
*.event
*.edf.event
*.zip
*.xlsx
*.mat
*.json
*.7z
# PyCharm
# JetBrains specific template is maintained in a separate repository that is not distributed with PyCharm itself
# https://github.com/github/gitignore/blob/main/Python.gitignore
.idea/
# Logs
*.log
# Other common ignores
node_modules/
dist/
tmp/
temp/
# Project-specific ignores
# Ignore all directories in the root
# merge64ch_0127/
/P300_speller/braindecode/
/P300_speller/data/
/P300_speller/pyRiemann/
/P300_speller/README/
/merge64ch_new/
/merge64ch_tianjinZMQdebug/

View File

@@ -0,0 +1,250 @@
# -*- coding: utf-8 -*-
"""
bdf_analyzer.py
Analyze .bdf files - print data amplitude range and mean values.
Supports single file or batch processing of all .bdf files in a directory.
"""
import os
import glob
import numpy as np
import mne
import scipy.signal as signal
def analyze_bdf(filepath: str, unit: str = "uV") -> "dict | None":
    """
    Analyze a single .bdf file, print its statistics and return them.

    Parameters
    ----------
    filepath : str
        Path to .bdf file
    unit : str, optional
        Display unit for the scaled statistics: "uV", "mV" or "V"
        (default: uV for microvolts). The raw statistics are always
        reported in volts, as returned by ``raw.get_data()``.

    Returns
    -------
    dict or None
        Dictionary containing statistics. The ``raw_*`` entries are in
        volts; ``min``/``max``/``mean``/``std`` and the per-channel entries
        are expressed in ``unit``. ``None`` is returned when the file could
        not be read or analyzed (the error is printed).

    Raises
    ------
    ValueError
        If ``unit`` is not one of the supported display units.
    """
    unit_scales = {"uV": 1e6, "mV": 1e3, "V": 1.0}
    if unit not in unit_scales:
        raise ValueError(
            f"Unsupported unit {unit!r}; expected one of {sorted(unit_scales)}"
        )
    scale = unit_scales[unit]

    print("=" * 60)
    print(f"File: {os.path.basename(filepath)}")
    print("=" * 60)
    try:
        # Read BDF file
        raw = mne.io.read_raw_bdf(filepath, preload=True, verbose=False)
        # Get data (n_channels, n_times) in V
        data = raw.get_data()
        n_channels, n_times = data.shape
        sfreq = raw.info["sfreq"]
        # Convert to the requested display unit
        data_scaled = data * scale

        # Raw data statistics (V); ravel() avoids copying the whole recording
        raw_all = data.ravel()
        raw_min = float(np.min(raw_all))
        raw_max = float(np.max(raw_all))
        raw_mean = float(np.mean(raw_all))
        raw_std = float(np.std(raw_all))

        # Overall statistics in the display unit
        all_values = data_scaled.ravel()
        min_val = np.min(all_values)
        max_val = np.max(all_values)
        mean_val = np.mean(all_values)
        std_val = np.std(all_values)

        print(f"Sampling rate: {sfreq:.2f} Hz")
        print(f"Channels: {n_channels}")
        print(f"Samples: {n_times:,}")
        print(f"Duration: {n_times / sfreq:.2f} sec")
        print("-" * 40)
        print("[RAW - V]")
        print(f"Amplitude range: [{raw_min:.6f}, {raw_max:.6f}] V")
        print(f"Mean value: {raw_mean:.6f} V")
        print(f"Std deviation: {raw_std:.6f} V")
        print(f"[RAW - {unit}]")
        print(f"Amplitude range: [{min_val:.4f}, {max_val:.4f}] {unit}")
        print(f"Mean value: {mean_val:.4f} {unit}")
        print(f"Std deviation: {std_val:.4f} {unit}")
        print("-" * 40)

        # Per-channel statistics
        print("\nPer-channel statistics:")
        min_hdr, max_hdr, mean_hdr = f"Min ({unit})", f"Max ({unit})", f"Mean ({unit})"
        print(f"{'Channel':<15} {min_hdr:<15} {max_hdr:<15} {mean_hdr:<15} {'PSD Peak (Hz)':<15}")
        print("-" * 75)

        # The Welch segment length is the same for every channel.
        nperseg = min(1024, n_times)
        channel_stats = []
        for i, ch_name in enumerate(raw.ch_names):
            ch_data = data_scaled[i, :]
            ch_min = np.min(ch_data)
            ch_max = np.max(ch_data)
            ch_mean = np.mean(ch_data)
            # PSD peak frequency (the location of the peak is unit-independent)
            freqs, pxx = signal.welch(ch_data, fs=sfreq, nperseg=nperseg)
            peak_freq = freqs[np.argmax(pxx)]
            print(f"{ch_name:<15} {ch_min:<15.4f} {ch_max:<15.4f} {ch_mean:<15.4f} {peak_freq:<15.2f}")
            channel_stats.append({
                "name": ch_name,
                "min": ch_min,
                "max": ch_max,
                "mean": ch_mean,
                "psd_peak_hz": peak_freq
            })
        print("=" * 60)
        print()
        return {
            "filepath": filepath,
            "sfreq": sfreq,
            "n_channels": n_channels,
            "n_times": n_times,
            "duration": n_times / sfreq,
            "raw_min": raw_min,
            "raw_max": raw_max,
            "raw_mean": raw_mean,
            "raw_std": raw_std,
            "min": min_val,
            "max": max_val,
            "mean": mean_val,
            "std": std_val,
            "channels": channel_stats
        }
    except Exception as e:
        # Best-effort boundary: batch callers skip files that fail to load.
        print(f"[ERROR] Failed to read file: {e}")
        return None
def analyze_directory(dir_path: str) -> list:
    """
    Analyze all .bdf files in a directory and print a cross-file summary.

    Parameters
    ----------
    dir_path : str
        Directory path

    Returns
    -------
    list
        List of analysis results (one dict per successfully analyzed file).
        Empty when the directory contains no ``*.bdf`` files or none of
        them could be analyzed.
    """
    # Find all .bdf files
    bdf_files = sorted(glob.glob(os.path.join(dir_path, "*.bdf")))
    if not bdf_files:
        print(f"[WARNING] No .bdf files found in: {dir_path}")
        return []
    print(f"Found {len(bdf_files)} .bdf file(s)\n")

    results = []
    for filepath in bdf_files:
        result = analyze_bdf(filepath)
        if result:
            results.append(result)
    if not results:
        return results

    # Summary statistics
    print("\n" + "=" * 60)
    print("Summary")
    print("=" * 60)
    all_means = [r["mean"] for r in results]
    all_mins = [r["min"] for r in results]
    all_maxs = [r["max"] for r in results]
    print(f"File count: {len(results)}")
    print(f"[RAW - V] Overall range: [{min(r['raw_min'] for r in results):.6f}, {max(r['raw_max'] for r in results):.6f}] V")
    print(f"[RAW - V] Avg mean: {np.mean([r['raw_mean'] for r in results]):.6f} V")
    print(f"[RAW - uV] Overall range: [{min(all_mins):.4f}, {max(all_maxs):.4f}] uV")
    print(f"[RAW - uV] Avg mean: {np.mean(all_means):.4f} uV")
    print(f"Max value file: {results[np.argmax(all_maxs)]['filepath']}")
    print(f"Min value file: {results[np.argmin(all_mins)]['filepath']}")

    # Per-channel summary across all files, keyed by channel name so files
    # with a different channel count or order cannot raise IndexError or be
    # averaged into the wrong channel. Insertion order keeps the first
    # file's channel order.
    per_channel = {}
    for file_result in results:
        for ch in file_result["channels"]:
            means, peaks = per_channel.setdefault(ch["name"], ([], []))
            means.append(ch["mean"])
            peaks.append(ch["psd_peak_hz"])

    print("\nPer-channel mean across all files:")
    print(f"{'Channel':<15} {'Mean (uV)':<15} {'PSD Peak (Hz)':<15}")
    print("-" * 45)
    for ch_name, (means, peaks) in per_channel.items():
        print(f"{ch_name:<15} {np.mean(means):<15.4f} {np.mean(peaks):<15.2f}")
    return results
def main():
    """Command-line entry point: analyze one .bdf file or a directory of them."""
    import argparse

    # Default analysis directory: "raw_data" next to this script
    script_dir = os.path.dirname(os.path.abspath(__file__))
    fallback_dir = os.path.join(script_dir, "raw_data")
    usage_examples = """
Examples:
python bdf_analyzer.py # Analyze all .bdf in raw_data/
python bdf_analyzer.py data/test.bdf # Analyze single file
python bdf_analyzer.py data/ # Analyze all .bdf in directory
python bdf_analyzer.py . -u mV # Current dir, unit mV
"""
    cli = argparse.ArgumentParser(
        description="Analyze .bdf files - print amplitude range and mean values",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )
    cli.add_argument(
        "path",
        nargs="?",
        default=fallback_dir,
        help="Path to BDF file or directory containing BDF files (default: raw_data/)",
    )
    cli.add_argument(
        "-u", "--unit",
        choices=["uV", "mV", "V"],
        default="uV",
        help="Display unit (default: uV)",
    )
    options = cli.parse_args()
    target = options.path

    # Single file mode
    if os.path.isfile(target):
        if analyze_bdf(target, unit=options.unit):
            print("Analysis complete!")
        return

    # Directory mode
    if os.path.isdir(target):
        if analyze_directory(target):
            print("\nBatch analysis complete!")
        else:
            print("No analyzable files found")
        return

    print(f"[ERROR] File does not exist: {target}")


if __name__ == "__main__":
    main()

203
algorithm_V1/bdf_to_mat.py Normal file
View File

@@ -0,0 +1,203 @@
# -*- coding: utf-8 -*-
"""
Convert BDF file to MAT format.
This script converts a BDF (Biosemi Data Format) EEG file to .mat format,
matching the structure of eeg_data.mat.
Structure of eeg_data.mat:
- data: (n_samples, n_channels) float64
- chn: (1, n_channels) object - channel names
- sample_rate: (1, 1) int64
- node_number: (1, 1) int64
- t: (n_samples, 1) float64 - time vector in seconds
- electrode_name: (1, n_channels) object - electrode names (10-20 system)
- electrode_xyz: (n_channels, 3) float64 - electrode 3D coordinates
- electrode_coord_system: (1,) <U21
- meta: (1, 1) structured - metadata
"""
import numpy as np
import scipy.io
import mne
from datetime import datetime
def get_standard_electrode_coords():
    """
    Build the electrode-name -> (x, y, z) lookup table.

    Returns a new dictionary on every call, keyed by upper-case electrode
    label, with one 3-tuple of floats per electrode. The values are
    approximate spherical projections of standard scalp positions
    (NOTE(review): units are presumably metres -- confirm).
    """
    # One row per electrode: (label, x, y, z). Row order defines dict order.
    rows = (
        # Frontal pole / anterior-frontal
        ("FP1", -0.0293, 0.0903, -0.0033),
        ("FP2", 0.0293, 0.0903, -0.0033),
        ("FPZ", 0.0, 0.0903, -0.0033),
        ("AF7", -0.0658, 0.0734, -0.0224),
        ("AF3", -0.0350, 0.0812, -0.0183),
        ("AF4", 0.0350, 0.0812, -0.0183),
        ("AF8", 0.0658, 0.0734, -0.0224),
        # Frontal
        ("F7", -0.0815, 0.0467, -0.0336),
        ("F5", -0.0667, 0.0503, -0.0351),
        ("F3", -0.0489, 0.0560, -0.0370),
        ("F1", -0.0254, 0.0584, -0.0384),
        ("FZ", 0.0, 0.0584, -0.0384),
        ("F2", 0.0254, 0.0584, -0.0384),
        ("F4", 0.0489, 0.0560, -0.0370),
        ("F6", 0.0667, 0.0503, -0.0351),
        ("F8", 0.0815, 0.0467, -0.0336),
        # Fronto-central / fronto-temporal
        ("FT7", -0.0880, 0.0229, -0.0397),
        ("FC5", -0.0699, 0.0317, -0.0402),
        ("FC3", -0.0514, 0.0362, -0.0411),
        ("FC1", -0.0268, 0.0383, -0.0419),
        ("FCZ", 0.0, 0.0383, -0.0419),
        ("FC2", 0.0268, 0.0383, -0.0419),
        ("FC4", 0.0514, 0.0362, -0.0411),
        ("FC6", 0.0699, 0.0317, -0.0402),
        ("FT8", 0.0880, 0.0229, -0.0397),
        # Temporal / central
        ("T7", -0.0958, 0.0, -0.0411),
        ("T8", 0.0958, 0.0, -0.0411),
        ("C5", -0.0739, 0.0, -0.0425),
        ("C3", -0.0544, 0.0, -0.0436),
        ("C1", -0.0283, 0.0, -0.0444),
        ("CZ", 0.0, 0.0, -0.0444),
        ("C2", 0.0283, 0.0, -0.0444),
        ("C4", 0.0544, 0.0, -0.0436),
        ("C6", 0.0739, 0.0, -0.0425),
        # Centro-parietal / temporo-parietal
        ("TP7", -0.0880, -0.0229, -0.0397),
        ("CP5", -0.0699, -0.0317, -0.0402),
        ("CP3", -0.0514, -0.0362, -0.0411),
        ("CP1", -0.0268, -0.0383, -0.0419),
        ("CPZ", 0.0, -0.0383, -0.0419),
        ("CP2", 0.0268, -0.0383, -0.0419),
        ("CP4", 0.0514, -0.0362, -0.0411),
        ("CP6", 0.0699, -0.0317, -0.0402),
        ("TP8", 0.0880, -0.0229, -0.0397),
        # Parietal
        ("P7", -0.0815, -0.0467, -0.0336),
        ("P5", -0.0667, -0.0503, -0.0351),
        ("P3", -0.0489, -0.0560, -0.0370),
        ("P1", -0.0254, -0.0584, -0.0384),
        ("PZ", 0.0, -0.0584, -0.0384),
        ("P2", 0.0254, -0.0584, -0.0384),
        ("P4", 0.0489, -0.0560, -0.0370),
        ("P6", 0.0667, -0.0503, -0.0351),
        ("P8", 0.0815, -0.0467, -0.0336),
        # Parieto-occipital / occipital
        ("PO7", -0.0658, -0.0734, -0.0224),
        ("PO5", -0.0503, -0.0744, -0.0258),
        ("PO3", -0.0350, -0.0812, -0.0183),
        ("POZ", 0.0, -0.0829, -0.0172),
        ("PO4", 0.0350, -0.0812, -0.0183),
        ("PO6", 0.0503, -0.0744, -0.0258),
        ("PO8", 0.0658, -0.0734, -0.0224),
        ("O1", -0.0293, -0.0903, -0.0033),
        ("OZ", 0.0, -0.0903, -0.0033),
        ("O2", 0.0293, -0.0903, -0.0033),
        # Remaining labels
        ("CB1", -0.0618, -0.0380, -0.0387),
        ("CB2", 0.0618, -0.0380, -0.0387),
        ("A1", -0.0958, 0.0, 0.0),
        ("A2", 0.0958, 0.0, 0.0),
    )
    return {label: (x, y, z) for label, x, y, z in rows}
def bdf_to_mat(bdf_path, output_path, subject_id='unknown', session_id='unknown'):
    """
    Convert BDF file to MAT format matching eeg_data.mat structure.

    The recording is loaded with MNE, transposed to (n_samples, n_channels)
    and written as a single struct named ``eeg`` via ``scipy.io.savemat``.

    Parameters
    ----------
    bdf_path : str
        Path to the input BDF file.
    output_path : str
        Path to the output MAT file.
    subject_id : str, optional
        Subject identifier. Default is 'unknown'.
    session_id : str, optional
        Session identifier. Default is 'unknown'.

    Notes
    -----
    Electrode coordinates are looked up case-insensitively (and ignoring
    surrounding whitespace) in the table from
    ``get_standard_electrode_coords``; channels without an entry keep
    (0, 0, 0) and a warning is printed.
    """
    print(f'Loading BDF file: {bdf_path}')
    raw = mne.io.read_raw_bdf(bdf_path, preload=True, verbose=False)

    # Get basic info
    ch_names = raw.ch_names
    n_channels = len(ch_names)
    # NOTE(review): int() truncates a non-integer sampling rate, which would
    # also skew the time vector below -- confirm rates are always integral.
    sfreq = int(raw.info['sfreq'])

    # BDF data shape: (n_channels, n_samples)
    # Convert to eeg_data.mat format: (n_samples, n_channels)
    data = raw.get_data().T
    n_samples = data.shape[0]

    # Create time vector (in seconds) as a column
    t = (np.arange(n_samples) / sfreq).reshape(-1, 1)

    # Channel and electrode names (identical for BDF input).
    # NOTE(review): these arrays are (n_channels, 1) here while the module
    # docstring describes (1, n_channels) -- confirm which is expected.
    chn = np.array([[name] for name in ch_names], dtype=object)
    electrode_name = np.array([[name] for name in ch_names], dtype=object)

    # Get electrode coordinates. The table keys are upper-case, so normalise
    # the channel label before the lookup ('Fp1' / ' CZ ' would otherwise miss).
    standard_coords = get_standard_electrode_coords()
    electrode_xyz = np.zeros((n_channels, 3))
    for i, name in enumerate(ch_names):
        xyz = standard_coords.get(name.strip().upper())
        if xyz is None:
            print(f'Warning: No standard coordinate for electrode {name}')
        else:
            electrode_xyz[i] = xyz

    # Create metadata structure.
    # NOTE(review): start_time is the conversion time, not the recording's
    # measurement date -- confirm this is intended.
    start_time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    meta = np.array([(subject_id, session_id, 'CMS/DRL', start_time_str)],
                    dtype=[('subject_id', 'O'), ('session_id', 'O'),
                           ('ref', 'O'), ('start_time', 'O')])

    # Create the EEG structure (matching eeg_data.mat format)
    eeg_struct = np.array([(data, chn, [[sfreq]], [[n_channels]],
                            t, electrode_name, electrode_xyz,
                            'buzsaki', meta)],
                          dtype=[('data', 'O'), ('chn', 'O'),
                                 ('sample_rate', 'O'), ('node_number', 'O'),
                                 ('t', 'O'), ('electrode_name', 'O'),
                                 ('electrode_xyz', 'O'), ('electrode_coord_system', 'O'),
                                 ('meta', 'O')])

    # Save to MAT file
    print(f'Saving to: {output_path}')
    scipy.io.savemat(output_path, {'eeg': eeg_struct}, do_compression=True)

    print(f'\nConversion complete!')
    print(f' Channels: {n_channels}')
    print(f' Samples: {n_samples}')
    print(f' Duration: {n_samples / sfreq:.2f} seconds')
    print(f' Sample rate: {sfreq} Hz')
    print(f' Data shape: {data.shape}')
def main():
    """
    Command-line entry point: convert one BDF file and verify the output.

    All arguments are optional; without any, the original hard-coded input
    file, output file, subject id and session id are used. When only the
    input path is given, the output path is the input path with a ``.mat``
    extension.
    """
    import argparse
    import os

    # Defaults reproduce the values that used to be hard-coded here.
    default_bdf = r'D:\Ivey\Code_New_Proj\Debug_Depression\algorithm_version_0521_v0\0515-18.bdf'

    parser = argparse.ArgumentParser(description='Convert a BDF EEG file to MAT format.')
    parser.add_argument('bdf_path', nargs='?', default=default_bdf,
                        help='Input BDF file')
    parser.add_argument('output_path', nargs='?', default=None,
                        help='Output MAT file (default: input path with .mat extension)')
    parser.add_argument('--subject-id', default='lvpeng', help='Subject identifier')
    parser.add_argument('--session-id', default='01', help='Session identifier')
    args = parser.parse_args()

    bdf_path = args.bdf_path
    output_path = args.output_path
    if output_path is None:
        output_path = os.path.splitext(bdf_path)[0] + '.mat'

    # Convert
    bdf_to_mat(bdf_path, output_path,
               subject_id=args.subject_id, session_id=args.session_id)

    # Verify the output by reading the file back
    print('\n=== Verification ===')
    mat_data = scipy.io.loadmat(output_path)
    eeg = mat_data['eeg'][0, 0]
    print(f'Output file keys: {list(mat_data.keys())}')
    print(f'eeg.data shape: {eeg["data"].shape}')
    print(f'eeg.chn shape: {eeg["chn"].shape}')
    print(f'eeg.sample_rate: {eeg["sample_rate"][0, 0]}')
    print(f'eeg.t shape: {eeg["t"].shape}')
    print(f'eeg.electrode_name: {eeg["electrode_name"].shape}')
    print(f'eeg.electrode_xyz shape: {eeg["electrode_xyz"].shape}')
    print(f'eeg.electrode_coord_system: {eeg["electrode_coord_system"][0]}')


if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,96 @@
# -*- mode: python ; coding: utf-8 -*-
import sys
import os
from PyInstaller.utils.hooks import collect_submodules, collect_data_files
# ========================================================
# 1. Project configuration
# ========================================================
block_cipher = None
ENTRY_POINT = 'runDecoder.py'
APP_NAME = 'Depression_Decoder'
# ========================================================
# 2. Dependency analysis
# ========================================================
# Hidden imports that the static analysis may miss (mne, sklearn, scipy).
# NOTE(review): 'sklearn.neighbors.typedefs' / 'sklearn.neighbors.quad_tree'
# look like module names from older scikit-learn releases -- confirm they
# still resolve for the installed version.
hidden_imports = [
    'infer_pth',  # project module that is imported dynamically
    'sklearn.utils._cython_blas',
    'sklearn.neighbors.typedefs',
    'sklearn.neighbors.quad_tree',
    'sklearn.tree',
    'sklearn.tree._utils',
]
# Collect every mne submodule automatically
hidden_imports += collect_submodules('mne')
# List torch explicitly (PyInstaller usually finds it, but explicit is safer)
hidden_imports += ['torch', 'torchvision']
# ========================================================
# 3. Data files
# ========================================================
# `datas` passed to Analysis are bundled inside the application itself.
# It starts empty and only receives the mne package data below; no project
# folders (model, raw_data) are added here.
datas = []
# Collect the data files shipped with mne (e.g. default configuration)
datas += collect_data_files('mne')
# ========================================================
# 4. Build process
# ========================================================
a = Analysis(
    [ENTRY_POINT],
    pathex=[],
    binaries=[],
    datas=datas,
    hiddenimports=hidden_imports,
    hookspath=[],
    hooksconfig={},
    runtime_hooks=[],
    excludes=['tkinter', 'PyQt5', 'PySide2', 'IPython'],  # drop GUI / interactive libraries to reduce size
    win_no_prefer_redirects=False,
    win_private_assemblies=False,
    cipher=block_cipher,
    noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
    pyz,
    a.scripts,
    [],
    exclude_binaries=True,
    name=APP_NAME,
    debug=False,
    bootloader_ignore_signals=False,
    strip=False,
    upx=False,
    console=True,
    disable_windowed_traceback=False,
    argv_emulation=False,
    target_arch=None,
    codesign_identity=None,
    entitlements_file=None,
)
# ========================================================
# 5. Packaging mode: one-dir (single folder)
# ========================================================
# NOTE(review): the original comment here described copying folders next to
# the executable with Tree('source', prefix='target'), but no Tree entry is
# passed to COLLECT below -- presumably the folders are copied by a separate
# build step; confirm.
coll = COLLECT(
    exe,
    a.binaries,
    a.zipfiles,
    a.datas,
    strip=False,
    upx=False,
    upx_exclude=[],
    name=APP_NAME,
)

View File

@@ -0,0 +1,77 @@
import os
import shutil
import subprocess
import sys
def main():
    """
    Build the Depression_Decoder application with PyInstaller.

    Steps:
    1. Remove the previous ``dist/`` and ``build/`` folders next to this script.
    2. Run PyInstaller on ``build_algorithm.spec`` from this script's
       directory, so the spec file and the generated ``dist/`` folder do not
       depend on the caller's working directory.
    3. Copy the ``model`` and ``raw_data`` folders next to the built
       executable.

    Exits with status 1 when PyInstaller fails or cannot be started.
    """
    # 1. Paths
    base_dir = os.path.dirname(os.path.abspath(__file__))
    dist_dir = os.path.join(base_dir, 'dist')
    build_dir = os.path.join(base_dir, 'build')
    app_name = 'Depression_Decoder'
    target_dir = os.path.join(dist_dir, app_name)

    # 2. Clean up old builds (best effort: a failure is only a warning)
    print("[1/3] Cleaning up old builds...")
    for label, path in (("dist", dist_dir), ("build", build_dir)):
        if os.path.exists(path):
            try:
                shutil.rmtree(path)
                print(f" Cleaned {label}/")
            except Exception as e:
                print(f" Warning: Could not clean {label}/: {e}")

    # 3. Run PyInstaller
    print("[2/3] Running PyInstaller...")
    # --noupx is not passed because upx=False is already set in the spec file.
    cmd = [
        "pyinstaller",
        "build_algorithm.spec",
        "--clean"
    ]
    try:
        # An argument list must not be combined with shell=True: on POSIX only
        # the first element would be executed and the spec would be dropped.
        subprocess.check_call(cmd, cwd=base_dir)
    except (subprocess.CalledProcessError, OSError):
        # OSError covers a missing / non-executable `pyinstaller` command.
        print("Error: PyInstaller failed.")
        sys.exit(1)

    # 4. Copy external resource folders next to the executable
    print("[3/3] Copying external resources...")
    for label in ("model", "raw_data"):
        src = os.path.join(base_dir, label)
        dst = os.path.join(target_dir, label)
        if os.path.exists(src):
            if os.path.exists(dst):
                shutil.rmtree(dst)
            shutil.copytree(src, dst)
            print(f" Copied: {label} -> {dst}")
        else:
            print(f" Warning: Source {label} dir not found at {src}")

    print("\n" + "="*50)
    print(f"SUCCESS! Build artifacts are in: {target_dir}")
    print("="*50)


if __name__ == "__main__":
    main()

561
algorithm_V1/infer_pth.py Normal file
View File

@@ -0,0 +1,561 @@
# -*- coding: utf-8 -*-
"""
infer_pth.py
用途:
- 从一个文件夹中自动读取第一个 .mat EEG 文件64通道或32通道
- 若为64通道则按 idx64_to_32 映射选出32通道
- 提取切片特征DE + PSD(var近似)不含Asym
- 加载你训练好的 .pth 模型FusionNet结构
- 输出该受试者的 HC / MDD 判断结果
运行方式(命令行):
python infer_pth.py --eeg_dir "D:\\xxx\\folder" --model_path "C:\\xxx\\model.pth"
也可在其他py里import
from infer_pth import predict_hc_mdd
res = predict_hc_mdd(eeg_dir, model_path)
print(res)
"""
from __future__ import annotations
import os
import argparse
import numpy as np
import scipy.io
import scipy.signal as signal
import torch
import torch.nn as nn
import torch.nn.functional as F
# =========================================================
# 0) Configuration (edit here as needed)
# =========================================================
# Sampling rate (must match the value used for training)
SAMPLING_RATE = 250
# Sliding-window parameters in samples (must match training)
WINDOW_SIZE = 500
STRIDE = 250
# Frequency bands in Hz (must match training)
BAND_NAMES = ["Delta", "Theta", "Alpha", "Beta", "Gamma"]
BANDS = {
    "Delta": (1, 4),
    "Theta": (4, 8),
    "Alpha": (8, 13),
    "Beta": (13, 30),
    "Gamma": (30, 50),
}
# Whether to use the extended feature set (DE + PSD)
USE_EXTENDED_FEATURES = True
# Numerical-stability term added before taking logarithms
EPS = 1e-12
# Compute device: CUDA when available, otherwise CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Channel mapping: 0-based column indices into the 64-channel recording,
# listed in the order of the 32 channels that are kept
IDX64_TO_32 = [
    23,  # C5
    47,  # O1
    39,  # TP7
    6,  # FPZ
    2,  # PO6
    21,  # P4
    35,  # AF7
    57,  # AF3
    1,  # FP2
    37,  # T7
    63,  # F1
    36,  # A1
    18,  # FC4
    31,  # FC5
    14,  # FC2
    48,  # T8
    60,  # P2
    41,  # AF8
    11,  # CP1
    0,  # FP1
    55,  # PO7
    59,  # C1
    22,  # F5
    10,  # CP2
    16,  # C3
    61,  # P1
    27,  # CP5
    17,  # C4
    26,  # CP6
    62,  # F2
    3,  # POZ
    13,  # PO5
]
# Subject-level decision threshold: a `subject_threshold` stored in the
# model checkpoint takes precedence; otherwise this value is used
DEFAULT_SUBJECT_THRESHOLD = 0.5
# =========================================================
# 1) 模型结构
# =========================================================
class SEBlock(nn.Module):
    """Feature-gating block.

    Scales each input feature by a sigmoid gate computed from the input
    itself through a two-layer bottleneck (``channels -> hidden -> channels``
    with ``hidden = max(1, channels // reduction)``).
    """

    def __init__(self, channels: int, reduction: int = 4) -> None:
        super().__init__()
        bottleneck = max(1, channels // reduction)
        self.fc1 = nn.Linear(channels, bottleneck)
        self.fc2 = nn.Linear(bottleneck, channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` multiplied element-wise by its learned gate."""
        squeezed = F.relu(self.fc1(x))
        gate = torch.sigmoid(self.fc2(squeezed))
        return x * gate
class ResidualBlock(nn.Module):
    """Two-layer fully connected block with batch norm and a skip connection.

    The skip path is the identity when ``in_features == out_features``;
    otherwise it is a linear projection followed by batch norm so both paths
    have ``out_features`` columns.
    """

    def __init__(self, in_features: int, out_features: int, dropout: float = 0.3) -> None:
        super().__init__()
        self.fc1 = nn.Linear(in_features, out_features)
        self.bn1 = nn.BatchNorm1d(out_features)
        self.fc2 = nn.Linear(out_features, out_features)
        self.bn2 = nn.BatchNorm1d(out_features)
        self.dropout = nn.Dropout(dropout)
        if in_features == out_features:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Sequential(
                nn.Linear(in_features, out_features),
                nn.BatchNorm1d(out_features),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(main_path(x) + shortcut(x))``."""
        skip = self.shortcut(x)
        hidden = self.dropout(F.relu(self.bn1(self.fc1(x))))
        hidden = self.bn2(self.fc2(hidden))
        return F.relu(hidden + skip)
class FusionNet(nn.Module):
    """Classifier over per-slice EEG feature vectors.

    Pipeline: input batch norm -> three residual blocks (512, 256, 128)
    -> feature gate -> 64-d embedding -> two heads. ``forward`` returns
    ``(cls_out, reg_out)`` with ``num_classes`` and ``num_scales`` columns
    respectively.
    """

    def __init__(self, num_classes: int = 2, num_eeg_features: int = 320, num_scales: int = 6) -> None:
        super().__init__()
        self.input_norm = nn.BatchNorm1d(num_eeg_features)
        self.block1 = ResidualBlock(num_eeg_features, 512, dropout=0.4)
        self.block2 = ResidualBlock(512, 256, dropout=0.3)
        self.block3 = ResidualBlock(256, 128, dropout=0.2)
        self.attention = SEBlock(128, reduction=4)
        self.final_fc = nn.Sequential(
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
        self.cls_head = nn.Linear(64, num_classes)
        # The regression head is kept so training checkpoints load as-is;
        # inference only uses the classification output.
        self.reg_head = nn.Sequential(
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(32, num_scales),
        )
        self._init_weights()

    def _init_weights(self) -> None:
        """Re-initialise every Linear and BatchNorm1d layer in the network."""
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)
                continue
            if isinstance(layer, nn.BatchNorm1d):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)

    def forward(self, x: torch.Tensor):
        """Return the ``(classification, regression)`` outputs for ``x``."""
        hidden = self.input_norm(x)
        for block in (self.block1, self.block2, self.block3):
            hidden = block(hidden)
        features = self.final_fc(self.attention(hidden))
        cls_out = self.cls_head(features)
        reg_out = self.reg_head(features)
        return cls_out, reg_out
# =========================================================
# 2) mat读取 + 通道裁剪
# =========================================================
def _find_first_mat_file(folder: str) -> str:
if not os.path.isdir(folder):
raise RuntimeError(f"eeg_dir 不是文件夹: {folder}")
mats = sorted([f for f in os.listdir(folder) if f.lower().endswith(".mat")])
if not mats:
raise RuntimeError(f"文件夹内没有 .mat 文件: {folder}")
return os.path.join(folder, mats[0])
import numpy as np
import scipy.io
def _unwrap_singleton(x):
"""
把 (1,1) / (1,) 这种包裹层一直剥掉,直到不是 singleton。
也处理 object array 的情况。
"""
while True:
if isinstance(x, np.ndarray):
if x.dtype == object and x.size == 1:
x = x.item()
continue
if x.size == 1 and x.ndim >= 1:
# 例如 (1,1) 或 (1,) 的数值/对象数组
try:
x = x.reshape(-1)[0]
continue
except Exception:
pass
break
return x
def _try_get_struct_field(v, field_name="data"):
"""
尝试从以下几种结构中提取字段:
1) scipy 读出的 mat_struct有 _fieldnames
2) numpy structured/record arraydtype.names
"""
# case 1: mat_struct推荐 loadmat(..., struct_as_record=False, squeeze_me=True)
if hasattr(v, "_fieldnames") and (field_name in getattr(v, "_fieldnames", [])):
return getattr(v, field_name)
# case 2: structured array
if isinstance(v, np.ndarray) and v.dtype.names and (field_name in v.dtype.names):
# 常见是 v[field] 仍然是 ndarray / object需要 unwrap
try:
return v[field_name]
except Exception:
return None
return None
def load_eeg_from_mat_any_channels(mat_path: str) -> np.ndarray:
    """
    Read the EEG matrix from a .mat file.

    Supported layouts:
    - a top-level 2-D numeric matrix, (T, C) or (C, T)
    - a struct with a ``data`` field (possibly wrapped in object arrays)

    When several candidates exist, the one with the highest heuristic score
    is used (see ``score`` below). The result is converted to float32 and
    oriented by a channel-count heuristic so that it is intended to be
    (T, C).

    Raises RuntimeError when no candidate is found or the selected value is
    not 2-D after unwrapping.
    """
    # These two options make structs easier to handle: fields become
    # attributes and singleton dimensions are squeezed away.
    mat = scipy.io.loadmat(mat_path, struct_as_record=False, squeeze_me=True)
    candidates = []
    for k, v in mat.items():
        # Skip loadmat's own metadata entries
        if k.startswith("__"):
            continue
        # --- 1) top-level 2-D numeric matrix ---
        if isinstance(v, np.ndarray) and v.ndim == 2 and np.issubdtype(v.dtype, np.number):
            candidates.append((k, v))
            continue
        # --- 2) struct/record: prefer its `data` field ---
        data_field = _try_get_struct_field(v, "data")
        if data_field is not None:
            data_field = _unwrap_singleton(data_field)
            # data_field may still be wrapped in one more object layer
            if isinstance(data_field, np.ndarray) and data_field.dtype == object:
                data_field = _unwrap_singleton(data_field)
            if isinstance(data_field, np.ndarray) and data_field.ndim == 2:
                # Accept numeric matrices; 2-D object arrays are also accepted
                # here and converted to float32 further below.
                if np.issubdtype(data_field.dtype, np.number) or data_field.dtype == object:
                    candidates.append((f"{k}.data", data_field))
                    continue
        # --- 3) object array: unwrap, then look for a 2-D numeric matrix or a struct ---
        if isinstance(v, np.ndarray) and v.dtype == object:
            vv = _unwrap_singleton(v)
            # Unwrapped value is a 2-D numeric matrix
            if isinstance(vv, np.ndarray) and vv.ndim == 2 and np.issubdtype(vv.dtype, np.number):
                candidates.append((k, vv))
                continue
            # Unwrapped value is a struct: take its `data` field
            data2 = _try_get_struct_field(vv, "data")
            if data2 is not None:
                data2 = _unwrap_singleton(data2)
                if isinstance(data2, np.ndarray) and data2.ndim == 2:
                    candidates.append((f"{k}.data", data2))
                    continue
    if not candidates:
        raise RuntimeError(f"mat里没找到可用EEG二维矩阵或struct.data{mat_path}")

    # Pick the candidate that looks most like EEG: a dimension equal to a
    # known channel count scores highest, and larger arrays score higher.
    def score(arr: np.ndarray) -> int:
        s = 0
        if 64 in arr.shape: s += 10
        if 32 in arr.shape: s += 9
        if 128 in arr.shape: s += 8
        if 129 in arr.shape: s += 7
        s += int(np.prod(arr.shape) // 100000)  # bigger arrays look more like EEG
        return s
    candidates.sort(key=lambda kv: score(kv[1]), reverse=True)
    key, eeg = candidates[0]
    eeg = _unwrap_singleton(eeg)
    # If the dtype is still object, try to convert the contents to float
    if isinstance(eeg, np.ndarray) and eeg.dtype == object:
        # object arrays sometimes hold plain numeric values
        eeg = np.array(eeg, dtype=np.float32)
    else:
        eeg = np.asarray(eeg, dtype=np.float32)
    if eeg.ndim != 2:
        raise RuntimeError(f"解析结果不是二维矩阵: key={key}, shape={eeg.shape}, file={mat_path}")
    # Normalise to (T, C). Inputs are usually (C, T) or (T, C); the channel
    # axis is recognised by its length being one of 32/64/128/129.
    # NOTE(review): a matrix whose channel count is not in that set is
    # returned in its stored orientation -- confirm this is acceptable.
    if eeg.shape[0] in (32, 64, 128, 129) and eeg.shape[1] not in (32, 64, 128, 129):
        eeg = eeg.T
    elif eeg.shape[1] in (32, 64, 128, 129):
        # Both axes look like channel counts: treat the smaller first axis
        # as channels, i.e. assume the input was (C, T)
        if eeg.shape[0] in (32, 64, 128, 129) and eeg.shape[0] < eeg.shape[1]:
            eeg = eeg.T
    return eeg
def ensure_32_channels(eeg: np.ndarray) -> np.ndarray:
    """
    Reduce a (T, C) recording to (T, 32).

    - C == 32: the input array is returned unchanged.
    - C == 64: the columns listed in ``IDX64_TO_32`` are selected, in order.
    - anything else (or a non-2-D input) raises ``RuntimeError``.
    """
    if eeg.ndim != 2:
        raise RuntimeError(f"EEG必须是二维(T,C),但得到: {eeg.shape}")
    n_channels = eeg.shape[1]
    if n_channels == 32:
        return eeg
    if n_channels != 64:
        raise RuntimeError(f"不支持的通道数C={n_channels},当前只支持 64->32 或 32 直推。")

    # 64-channel input: validate the mapping before indexing with it
    picks = np.asarray(IDX64_TO_32, dtype=np.int64)
    lowest, highest = picks.min(), picks.max()
    if lowest < 0 or highest >= 64:
        raise RuntimeError(f"IDX64_TO_32 越界min={lowest}, max={highest} (要求0~63)")
    return eeg[:, picks]
# =========================================================
# 3) 特征提取DE + PSD(var近似)
# =========================================================
class FeatureExtractor32:
    """
    Sliding-window band features for 32-channel EEG.

    Per window the feature vector is:
    - USE_EXTENDED_FEATURES=True: DE (32*5) + PSD approximation (32*5) = 320
    - otherwise: DE (32*5) = 160

    The signal is band-pass filtered once per band (Butterworth, second-order
    sections) and the per-window variance of each filtered channel is turned
    into differential entropy (DE) and, optionally, log-variance as a PSD
    approximation.
    """

    def __init__(
        self,
        fs: int = SAMPLING_RATE,
        window_size: int = WINDOW_SIZE,
        stride: int = STRIDE,
        filter_order: int = 4,
        zero_phase: bool = False,
    ) -> None:
        """Store the windowing parameters and design one filter per band.

        ``zero_phase`` selects forward-backward filtering (``sosfiltfilt``)
        instead of the default causal ``sosfilt``.
        """
        self.fs = fs
        self.window_size = window_size
        self.stride = stride
        self.filter_order = filter_order
        self.zero_phase = zero_phase
        # Band name -> second-order-sections filter coefficients
        self._sos = {}
        for bn in BAND_NAMES:
            low, high = BANDS[bn]
            self._sos[bn] = signal.butter(
                self.filter_order, [low, high],
                btype="band", fs=self.fs, output="sos"
            )

    def _filter_bands(self, eeg: np.ndarray) -> dict[str, np.ndarray]:
        """Return ``{band name: eeg filtered along axis 0 as float32}``."""
        apply_filter = signal.sosfiltfilt if self.zero_phase else signal.sosfilt
        return {
            bn: apply_filter(self._sos[bn], eeg, axis=0).astype(np.float32)
            for bn in BAND_NAMES
        }

    def extract(self, eeg32: np.ndarray) -> np.ndarray:
        """
        Compute the per-window feature matrix.

        eeg32: (T, 32)
        return: feats (N_slices, feat_dim), float32

        Raises RuntimeError when the input is not (T, 32) or when T is
        shorter than one window.
        """
        if eeg32.ndim != 2 or eeg32.shape[1] != 32:
            raise RuntimeError(f"extract需要 (T,32),得到 {eeg32.shape}")
        bands_data = self._filter_bands(eeg32)
        T = eeg32.shape[0]
        feats = []
        # "+ 1" keeps the last full window: a window starting at
        # T - window_size still fits entirely inside the recording.
        for start in range(0, T - self.window_size + 1, self.stride):
            end = start + self.window_size
            de_list = []
            psd_list = []
            for bn in BAND_NAMES:
                seg = bands_data[bn][start:end, :]  # (W, 32)
                var = np.var(seg, axis=0, ddof=1)  # (32,)
                # Differential entropy of a Gaussian with this variance
                de = 0.5 * np.log(2 * np.pi * np.e * (var + EPS))
                de_list.append(de)
                if USE_EXTENDED_FEATURES:
                    # PSD approximation: log(var)
                    psd_list.append(np.log(var + EPS))
            # Stack to (5, 32), transpose to (32, 5), flatten channel-major
            de_feat = np.stack(de_list, axis=0).T.reshape(-1)  # (32*5,)
            if USE_EXTENDED_FEATURES:
                psd_feat = np.stack(psd_list, axis=0).T.reshape(-1)  # (32*5,)
                f = np.concatenate([de_feat, psd_feat], axis=0).astype(np.float32)
            else:
                f = de_feat.astype(np.float32)
            feats.append(f)
        if not feats:
            raise RuntimeError("EEG长度不足以切片请检查T是否太短或调整WINDOW_SIZE/STRIDE")
        return np.stack(feats, axis=0).astype(np.float32)
# =========================================================
# 4) 模型加载 + 推理接口
# =========================================================
def _safe_torch_load(path: str):
    """Load a checkpoint onto ``DEVICE`` across torch versions.

    ``weights_only=False`` is passed explicitly; if this torch version does
    not accept that keyword (``TypeError``), the call is retried without it.

    NOTE(review): with ``weights_only=False`` the file is unpickled, which
    can execute arbitrary code -- only load checkpoints from trusted sources.
    """
    common_kwargs = {"map_location": DEVICE}
    try:
        loaded = torch.load(path, weights_only=False, **common_kwargs)
    except TypeError:
        loaded = torch.load(path, **common_kwargs)
    return loaded
def load_model(model_path: str) -> tuple[FusionNet, dict]:
    """
    Build a FusionNet from a saved file and put it in eval mode.

    Two file layouts are accepted:
    - a checkpoint dict containing ``"model_state"`` (and optionally
      ``"feat_dim"``, default 320): the dict is returned as ``ckpt``;
    - anything else is treated as a bare state dict with 320 input
      features, and ``ckpt`` is an empty dict.

    Returns: (model, ckpt_dict)
    """
    loaded = _safe_torch_load(model_path)
    is_checkpoint = isinstance(loaded, dict) and "model_state" in loaded
    if is_checkpoint:
        ckpt, state = loaded, loaded["model_state"]
        feat_dim = int(loaded.get("feat_dim", 320))
    else:
        ckpt, state, feat_dim = {}, loaded, 320

    net = FusionNet(num_classes=2, num_eeg_features=feat_dim).to(DEVICE)
    # strict=True: every parameter name must match the architecture
    net.load_state_dict(state, strict=True)
    net.eval()
    return net, ckpt
def predict_hc_mdd(eeg_dir: str, model_path: str) -> dict:
    """
    Classify one subject as HC or MDD from the first .mat file in ``eeg_dir``.

    Steps: read the EEG, reduce it to 32 channels, extract per-slice
    features, load the model, optionally normalise the features with the
    statistics stored in the checkpoint, then average p(MDD) over all slices
    and compare it with the subject-level threshold.

    Returned fields:
    - mat_file: the .mat file that was used
    - pred_label: "HC" or "MDD"
    - p_mdd_mean: mean of the per-slice p(MDD)
    - threshold: subject-level decision threshold
    - n_slices: number of slices

    A ``RuntimeWarning`` is issued when the checkpoint carries
    ``global_mean``/``global_std`` whose size does not match the feature
    dimension; in that case the features are used without normalisation.
    """
    import warnings

    mat_file = _find_first_mat_file(eeg_dir)

    # 1) Read the EEG as (T, C) and reduce it to 32 channels
    eeg = load_eeg_from_mat_any_channels(mat_file)  # (T, C)
    eeg32 = ensure_32_channels(eeg)  # (T, 32)

    # 2) Extract features (N, feat_dim)
    extractor = FeatureExtractor32(fs=SAMPLING_RATE, window_size=WINDOW_SIZE, stride=STRIDE)
    feats = extractor.extract(eeg32)  # (N, dim)

    # 3) Load the model
    model, ckpt = load_model(model_path)
    if not isinstance(ckpt, dict):
        ckpt = {}

    # 4) Optional normalisation with the statistics saved in the checkpoint.
    # The statistics are flattened first so that shapes such as (1, D) or
    # (D, 1) are accepted as well as (D,).
    mean = ckpt.get("global_mean", None)
    std = ckpt.get("global_std", None)
    if mean is not None and std is not None:
        mean = np.asarray(mean, dtype=np.float32).reshape(-1)
        std = np.asarray(std, dtype=np.float32).reshape(-1)
        feat_dim = feats.shape[1]
        if mean.size == feat_dim and std.size == feat_dim:
            feats = (feats - mean) / (std + 1e-8)
        else:
            # Keep running (best effort), but do not hide the mismatch:
            # un-normalised features can change the prediction.
            warnings.warn(
                "Checkpoint normalisation statistics do not match the feature "
                f"dimension (mean: {mean.size}, std: {std.size}, features: "
                f"{feat_dim}); features are used without normalisation.",
                RuntimeWarning,
                stacklevel=2,
            )

    # 5) Inference: p(MDD) for every slice, averaged at subject level
    x = torch.from_numpy(feats).to(DEVICE)
    with torch.no_grad():
        cls_out, _ = model(x)
        prob_mdd = torch.softmax(cls_out, dim=1)[:, 1].detach().cpu().numpy()
    p_mdd_mean = float(np.mean(prob_mdd))

    # A threshold stored in the checkpoint takes precedence over the default
    thr = float(ckpt.get("subject_threshold", DEFAULT_SUBJECT_THRESHOLD))
    pred_label = "MDD" if p_mdd_mean >= thr else "HC"
    return {
        "mat_file": mat_file,
        "pred_label": pred_label,
        "p_mdd_mean": p_mdd_mean,
        "threshold": thr,
        "n_slices": int(feats.shape[0]),
    }
# =========================================================
# 5) CLI命令行运行入口
# =========================================================
def main():
    """Command-line entry point: run the prediction and print a short report."""
    cli = argparse.ArgumentParser(description="Infer HC/MDD from 64ch->32ch EEG mat using a .pth FusionNet model (no Asym).")
    cli.add_argument("--eeg_dir", type=str, required=True, help="包含.mat EEG文件的文件夹自动读取第一个.mat")
    cli.add_argument("--model_path", type=str, required=True, help="训练好的.pth模型路径")
    options = cli.parse_args()

    outcome = predict_hc_mdd(options.eeg_dir, options.model_path)

    # One entry per printed line, in output order
    report_lines = [
        "\n========== 推理结果 ==========",
        f"MAT文件: {outcome['mat_file']}",
        f"切片数量: {outcome['n_slices']}",
        f"p(MDD)_mean: {outcome['p_mdd_mean']:.4f}",
        f"阈值thr: {outcome['threshold']:.4f}",
        f"预测结果: {outcome['pred_label']}",
        "==============================\n",
    ]
    for line in report_lines:
        print(line)


if __name__ == "__main__":
    main()

Binary file not shown.

Binary file not shown.

BIN
algorithm_V1/out/EEG.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 683 KiB

View File

@@ -0,0 +1,9 @@
中央区α/β波比值:0.5
额区α/β波比值:0.6
顶区α/β波比值:0.6
中央区θ/β波比值:0.9
顶区θ/β波比值:0.9
前额叶α波不对称性:-0.0
个体化α峰值频率:8.5
前额叶θ+δ波功率:74.1
是否推荐治疗:是

Binary file not shown.

After

Width:  |  Height:  |  Size: 247 KiB

BIN
algorithm_V1/out/psd.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 107 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 298 KiB

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,6 @@
numpy
scipy
matplotlib
mne
torch
scikit-learn

986
algorithm_V1/runDecoder.py Normal file
View File

@@ -0,0 +1,986 @@
# -*- coding: utf-8 -*-
"""
runDecoder.py - BDF EEG Depression Assessment
功能:
1. 读取 raw_data 文件夹的第一个 .bdf 格式文件
2. 预处理坏通道剔除、50Hz陷波、0.8-40Hz带通、幅值过滤、ICA去伪迹
3. 调用 infer_pth.py 中的 predict_hc_mdd 进行 HC/MDD 分类预测
4. 保存图表EEG、PSD、Topomap
5. 生成 ResultData.txt
"""
import matplotlib
matplotlib.use('Agg')
import numpy as np
import os
import shutil
import scipy.signal as signal
import matplotlib.pyplot as plt
import mne
from mne.preprocessing import ICA
# ==========================
# Config - preprocessing parameters
# ==========================
# Filter parameters
BANDPASS_LOW = 0.8
BANDPASS_HIGH = 40.0
NOTCH_FREQS = [50, 100] # power-line notch frequencies
# Amplitude filter thresholds (μV)
AMPLITUDE_MIN_UV = -200.0
AMPLITUDE_MAX_UV = 200.0
# ICA parameters
ICA_N_COMPONENTS = 20 # an absolute number of components rather than a fraction
ICA_RANDOM_STATE = 97
ICA_MAX_ITER = 800
# Bad-segment detection threshold (μV)
BAD_SEGMENT_THRESHOLD_UV = 350.0
# Default sampling rate
DEFAULT_FS = 250.0
# Plotting parameters
EEG_PLOT_SECONDS = 10
PSD_FMIN, PSD_FMAX = 0.8, 45.0
FIXED_EEG_IDXS = [23, 47, 39, 6, 2, 21, 35, 57] # 0-based index
FIXED_EEG_LABELS = ["C5", "O1", "TP7", "FPZ", "PO6", "P4", "AF7", "AF3"]
# Frequency band definitions
BANDS_METRICS = {
    "Delta": (1.0, 4.0),
    "Theta": (4.0, 8.0),
    "Alpha": (8.0, 13.0),
    "Beta": (13.0, 30.0),
}
TOTAL_POWER_BAND = (1.0, 50.0)
BANDS_TOPOMAP = {
    "delta": (1.0, 4.0),
    "theta": (4.0, 8.0),
    "alpha": (8.0, 13.0),
    "beta": (13.0, 30.0),
    "broad": (1.0, 30.0),
}
# PSD parameters
PSD_NPERSEG = 1024 # FFT window size; larger values give finer frequency resolution
EPS = 1e-12
# Topomap color range parameters
# None means an automatic range; set both to numbers for a fixed (min, max) range
TOPOMAP_VMIN = None # e.g. -1.0
TOPOMAP_VMAX = None # e.g. 1.0
# Alternatively use a symmetric range around the mean, in multiples of the std
TOPOMAP_SYM_SCALE = 1.5 # color range = mean ± std * SYM_SCALE
# Topomap head-circle radius (suggested range 0.08 - 0.15)
# Smaller values draw a smaller circle, larger values a larger one
TOPOMAP_SPHERE_RADIUS = 0.12
# Edge-handling parameter
# Seconds trimmed from each end after band-pass filtering, to suppress edge ringing
FILTER_PAD_SEC = 1.0
# ==========================
# 数据文件读取
# ==========================
def load_data_file(file_path: str) -> tuple:
    """Load a recording with the loader that matches its file extension.

    ``.bdf`` files go to ``load_bdf_file`` and ``.mat`` files to
    ``load_mat_file``; the extension comparison is case-insensitive.

    Returns
    -------
    tuple
        Whatever the selected loader returns: ``(raw, sfreq, ch_names)``.

    Raises
    ------
    ValueError
        If the extension is neither ``.bdf`` nor ``.mat``.
    """
    suffix = os.path.splitext(file_path)[1].lower()
    if suffix not in (".bdf", ".mat"):
        raise ValueError(f"不支持的文件格式: {suffix}")
    # Loaders are looked up lazily so that only the selected one is resolved.
    if suffix == ".bdf":
        return load_bdf_file(file_path)
    return load_mat_file(file_path)
def load_bdf_file(bdf_path: str) -> tuple:
    """Read a .bdf file into memory and try to attach the standard_1020 montage.

    A montage failure is reported as a warning and does not abort loading.

    Returns
    -------
    tuple
        ``(raw, sfreq, ch_names)`` taken from the loaded MNE Raw object.
    """
    print(f"[INFO] Reading BDF file: {bdf_path}")
    raw = mne.io.read_raw_bdf(bdf_path, preload=True, verbose=False)

    # Best effort: channels without a standard position are ignored.
    try:
        raw.set_montage("standard_1020", on_missing="ignore")
    except Exception as err:
        print(f"[WARN] Failed to set standard_1020 montage: {err}")

    names = raw.ch_names
    rate = raw.info['sfreq']
    span = raw.times[-1] - raw.times[0]
    print(f"[INFO] Channels: {len(names)}, Duration: {span:.2f}s, SFreq: {rate:.2f}Hz")
    return raw, rate, names
def load_mat_file(mat_path: str) -> tuple:
    """Read a .mat file with an ``eeg`` struct and wrap it in an MNE RawArray.

    The struct is read for the fields ``data``, ``sample_rate``,
    ``electrode_name`` and (optionally) ``electrode_xyz``.

    Parameters
    ----------
    mat_path : str
        Path to the .mat file.

    Returns
    -------
    tuple
        ``(raw, sfreq, ch_names)``.
    """
    import scipy.io

    def _as_text(value) -> str:
        # bytes must be decoded; str(b"Fp1") would yield the literal "b'Fp1'".
        if isinstance(value, bytes):
            return value.decode("utf-8", errors="replace")
        return str(value)

    print(f"[INFO] Reading MAT file: {mat_path}")
    mat = scipy.io.loadmat(mat_path)
    eeg = mat['eeg'][0, 0]

    # Orient the samples as (T, C): the longer axis is treated as time.
    data = eeg['data']
    if data.shape[0] < data.shape[1]:
        data = data.T
    data = data.astype(np.float64)
    n_samples, n_channels = data.shape
    # NOTE(review): RawArray treats EEG values as volts, and the rest of this
    # file converts with 1e6 / 1e-6 on that basis. No unit conversion is done
    # here — confirm the unit of 'data' in the .mat files.

    sfreq = float(eeg['sample_rate'][0, 0])

    ch_names_raw = eeg['electrode_name']
    if ch_names_raw.ndim == 2:
        ch_names = [_as_text(ch[0]) for ch in ch_names_raw[0]]
    else:
        ch_names = [f"EEG{i+1}" for i in range(n_channels)]
    if len(ch_names) != n_channels:
        # create_info needs exactly one name per channel.
        print(f"[WARN] electrode_name has {len(ch_names)} entries for {n_channels} channels, using generic names")
        ch_names = [f"EEG{i+1}" for i in range(n_channels)]

    duration = n_samples / sfreq
    print(f"[INFO] Channels: {n_channels}, Duration: {duration:.2f}s, SFreq: {sfreq:.2f}Hz")

    info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=["eeg"] * n_channels)
    raw = mne.io.RawArray(data.T, info, verbose=False)  # (T, C) -> (C, T)

    # Prefer electrode positions stored in the file; otherwise standard_1020.
    try:
        electrode_xyz = eeg['electrode_xyz']
        if electrode_xyz.shape[0] == n_channels:
            # Division by 1000 presumably converts millimetres to metres —
            # TODO confirm the unit of electrode_xyz.
            ch_pos = {name: electrode_xyz[i] / 1000.0 for i, name in enumerate(ch_names)}
            montage = mne.channels.make_dig_montage(ch_pos=ch_pos, coord_frame='head')
            # Set the montage on raw itself: RawArray copies `info` on
            # construction, so changing `info` afterwards would not reach raw.
            raw.set_montage(montage)
            print("[INFO] Applied electrode positions from mat file")
        else:
            raw.set_montage("standard_1020", on_missing="ignore")
    except Exception as e:
        print(f"[WARN] Failed to set montage from mat file: {e}")
        try:
            raw.set_montage("standard_1020", on_missing="ignore")
        except Exception as fallback_err:
            print(f"[WARN] Failed to set standard_1020 montage: {fallback_err}")
    return raw, sfreq, ch_names
# ==========================
# 坏通道检测
# ==========================
def detect_bad_channels(raw: mne.io.RawArray, z_thresh: float = 3.0) -> list:
    """Return the names of channels that look unusable.

    Two criteria are applied in order:
    1. flat channels, whose peak-to-peak value or standard deviation is
       below 1e-12;
    2. channels whose peak-to-peak value is an outlier, measured as a
       MAD-based z-score against the non-flat channels.

    Parameters
    ----------
    raw : mne.io.RawArray
        Recording to inspect; only ``get_data()`` and ``ch_names`` are used.
    z_thresh : float
        Z-score above which a channel is flagged as an outlier.

    Returns
    -------
    list
        Flat channels first, then outliers, each in channel order.
    """
    samples = raw.get_data()
    names = raw.ch_names
    spans = np.ptp(samples, axis=1)
    spreads = np.std(samples, axis=1)

    # Criterion 1: flat / constant channels.
    is_flat = (spans < 1e-12) | (spreads < 1e-12)
    flagged = [names[k] for k in np.flatnonzero(is_flat)]

    # Criterion 2: the robust statistics need more than two usable channels.
    usable = np.array([label not in flagged for label in names])
    if usable.sum() > 2:
        reference = spans[usable]
        centre = np.median(reference)
        mad = np.median(np.abs(reference - centre)) + 1e-30
        scores = np.abs(spans - centre) / (mad * 1.4826)
        for k in np.flatnonzero(scores > z_thresh):
            if names[k] not in flagged:
                flagged.append(names[k])

    if not flagged:
        print("[INFO] No bad channels detected")
    else:
        print(f"[INFO] Bad channels detected: {flagged}")
    return flagged
# ==========================
# 坏段标注
# ==========================
def annotate_bad_segments(raw: mne.io.RawArray, peak_to_peak_uv: float = 350.0):
    """Annotate 1-second windows whose peak-to-peak amplitude is too large.

    Windows are one second long and advance by half a second. A window is
    marked ``BAD_SEG`` when any channel's peak-to-peak value exceeds the
    threshold. The annotations are installed with ``raw.set_annotations``,
    which replaces the annotations currently on ``raw``.

    Parameters
    ----------
    raw : mne.io.RawArray
        Recording to annotate in place.
    peak_to_peak_uv : float
        Threshold in microvolts; converted to volts before comparison.
    """
    sfreq = raw.info["sfreq"]
    peak_to_peak_v = peak_to_peak_uv * 1e-6
    win = int(sfreq * 1.0)
    # A zero step would make range() raise at very low sampling rates.
    step = max(1, int(sfreq * 0.5))
    data = raw.get_data()
    n_times = data.shape[1]
    if win < 1 or n_times < win:
        return

    onsets = []
    # "+ 1" so the last full window (start == n_times - win) is also checked.
    for start in range(0, n_times - win + 1, step):
        ptp = np.ptp(data[:, start:start + win], axis=1)
        if np.any(ptp > peak_to_peak_v):
            onsets.append(start / sfreq)

    if onsets:
        durations = [win / sfreq] * len(onsets)
        ann = mne.Annotations(onset=onsets, duration=durations, description=["BAD_SEG"] * len(onsets))
        raw.set_annotations(ann)
        print(f"[INFO] Annotated {len(onsets)} bad segments")
# ==========================
# 核心预处理函数
# ==========================
def preprocess_bdf(raw: mne.io.RawArray) -> mne.io.RawArray:
    """Run the full preprocessing pipeline on a loaded recording.

    Steps, in order: crop 2 s from each end, remove the per-channel mean,
    detect and interpolate bad channels, notch filter, band-pass filter
    (then trim FILTER_PAD_SEC from each end), zero samples above the
    amplitude threshold, annotate bad segments, remove EOG-related ICA
    components, and remove the per-channel mean again.

    Parameters
    ----------
    raw : mne.io.RawArray
        Recording to clean; it is modified by this function.

    Returns
    -------
    mne.io.RawArray
        The cleaned recording produced by applying ICA to a copy.
    """
    print("[INFO] Starting preprocessing pipeline...")

    # 1) Crop the first and last 2 s (skipped when the recording is too short).
    crop_sec = 2.0
    t_start = crop_sec
    t_end = raw.times[-1] - crop_sec
    if t_end > t_start:
        raw = raw.crop(tmin=t_start, tmax=t_end)
        print(f"[INFO] Cropped: removed first/last {crop_sec}s")

    # 2) Remove the DC offset of every channel.
    data = raw.get_data()
    data -= data.mean(axis=1, keepdims=True)
    raw._data = data
    print("[INFO] Removed DC offset")

    # 3) Bad-channel detection and interpolation (best effort).
    bad_chs = detect_bad_channels(raw)
    if bad_chs:
        raw.info["bads"] = bad_chs
        try:
            raw_tmp = raw.copy()
            raw_tmp.set_montage(raw.get_montage(), on_missing="ignore")
            raw_tmp.interpolate_bads(reset_bads=True, verbose=False)
            raw = raw_tmp
            print(f"[INFO] Bad channels interpolated: {bad_chs}")
        except Exception as e:
            # Interpolation failed: keep the channels and clear the bad marks.
            print(f"[WARN] Bad channel interpolation failed: {e}")
            raw.info["bads"] = []

    # 4) Notch filter at the power-line frequencies.
    print(f"[INFO] Applying notch filter: {NOTCH_FREQS}Hz")
    raw.notch_filter(NOTCH_FREQS, fir_design="firwin", verbose=False)

    # 5) Band-pass filter, then trim the edges to drop filter ringing.
    print(f"[INFO] Applying bandpass filter: {BANDPASS_LOW}-{BANDPASS_HIGH}Hz (with {FILTER_PAD_SEC}s padding)")
    pad_sec = FILTER_PAD_SEC
    raw_length = raw.times[-1]
    pad = max(0, pad_sec)
    if raw_length > 2 * pad + 1.0:
        raw.filter(BANDPASS_LOW, BANDPASS_HIGH, fir_design="firwin",
                   pad='reflect', verbose=False)
        raw = raw.crop(tmin=pad_sec, tmax=raw_length - pad_sec)
        print(f"[INFO] Removed {pad_sec}s padding from each side after filtering")
    else:
        raw.filter(BANDPASS_LOW, BANDPASS_HIGH, fir_design="firwin", verbose=False)
        print(f"[WARN] Data too short ({raw_length:.1f}s), skipping padding")

    # 6) Amplitude filter: samples beyond ±AMPLITUDE_MAX_UV are set to zero.
    #    AMPLITUDE_MIN_UV only appears in the log line; the test is symmetric.
    print(f"[INFO] Applying amplitude filter: [{AMPLITUDE_MIN_UV}, {AMPLITUDE_MAX_UV}] μV")
    amplitude_thresh_v = AMPLITUDE_MAX_UV * 1e-6
    d = raw.get_data()
    mask = np.abs(d) > amplitude_thresh_v
    n_clipped = int(mask.sum())
    if n_clipped > 0:
        d[mask] = 0.0
        raw._data = d
        print(f"[INFO] Amplitude clipping: {n_clipped} samples exceeded ±{AMPLITUDE_MAX_UV:g}μV, set to 0")

    # 7) Mark bad segments so the ICA fit can skip them.
    annotate_bad_segments(raw, peak_to_peak_uv=BAD_SEGMENT_THRESHOLD_UV)

    # 8) ICA artifact removal. The component count is capped at the number of
    #    good EEG channels, because ICA rejects n_components above that count.
    print("[INFO] Running ICA for artifact removal...")
    n_eeg = len(mne.pick_types(raw.info, eeg=True, exclude="bads"))
    n_components = min(ICA_N_COMPONENTS, n_eeg) if n_eeg > 0 else ICA_N_COMPONENTS
    if n_components < ICA_N_COMPONENTS:
        print(f"[WARN] Only {n_eeg} EEG channels available, using {n_components} ICA components")
    ica = ICA(n_components=n_components, random_state=ICA_RANDOM_STATE,
              max_iter=ICA_MAX_ITER, method="fastica")
    ica.fit(raw, reject_by_annotation=True, verbose=False)
    try:
        eog_inds, _ = ica.find_bads_eog(raw, verbose=False)
        if eog_inds:
            ica.exclude.extend(eog_inds)
            print(f"[INFO] ICA exclude EOG components: {eog_inds}")
    except Exception as e:
        # EOG detection is optional; its failure is only reported.
        print(f"[WARN] ICA EOG detection skipped: {e}")
    raw_clean = ica.apply(raw.copy(), verbose=False)

    # 9) Remove the DC offset once more after ICA.
    d = raw_clean.get_data()
    d -= d.mean(axis=1, keepdims=True)
    raw_clean._data = d
    print("[INFO] Preprocessing completed")
    return raw_clean
# ==========================
# 输出目录管理
# ==========================
def ensure_outdir(out_root: str) -> str:
    """Create the output directory, or empty it if it already exists.

    When the directory exists, every entry except ``ResultData.txt`` is
    removed. Deletion failures are printed as warnings and do not stop the
    cleanup.

    Returns
    -------
    str
        ``out_root`` unchanged.
    """
    if not os.path.exists(out_root):
        os.makedirs(out_root, exist_ok=True)
        return out_root

    for entry in os.listdir(out_root):
        if entry == "ResultData.txt":
            continue  # the previous result file is preserved
        target = os.path.join(out_root, entry)
        try:
            # Files and symlinks are unlinked; real directories are removed
            # recursively.
            if os.path.isfile(target) or os.path.islink(target):
                os.unlink(target)
            elif os.path.isdir(target):
                shutil.rmtree(target)
        except Exception as err:
            print(f"[WARN] Failed to delete {target}: {err}")
    return out_root
# ==========================
# 通道分区
# ==========================
def _norm_name(s: str) -> str:
    """Normalise a channel name: strip ends, upper-case, drop inner spaces."""
    text = str(s)
    trimmed = text.strip()
    return trimmed.upper().replace(" ", "")
def build_channel_index_map(ch_names, n_channels: int):
    """Map normalised channel names to their positions.

    Returns an empty dict when ``ch_names`` is empty or its length does not
    equal ``n_channels``.
    """
    if ch_names and len(ch_names) == n_channels:
        lookup = {}
        for position, label in enumerate(ch_names):
            lookup[_norm_name(label)] = position
        return lookup
    return {}
def pick_indices_by_names(name_to_idx, names):
    """Return the sorted, de-duplicated indices of the requested channels.

    Each requested name is normalised before lookup; names missing from
    ``name_to_idx`` are skipped.
    """
    wanted = (_norm_name(label) for label in names)
    found = {name_to_idx[key] for key in wanted if key in name_to_idx}
    return sorted(found)
def _fallback_region_indices(n_channels: int):
    """Split channel positions into regions when names cannot be used.

    The split is purely positional: the first 33% of indices are frontal,
    the next 33% central, the rest parietal/posterior; even indices count as
    left and odd indices as right.

    Returns
    -------
    tuple
        ``(frontal, central, parietal, prefrontal, posterior, left, right)``
        as lists of indices.
    """
    first_cut = int(n_channels * 0.33)
    second_cut = int(n_channels * 0.66)

    frontal = list(range(first_cut))
    central = list(range(first_cut, second_cut))
    parietal = list(range(second_cut, n_channels))
    # At least two indices, even when that exceeds the frontal block.
    prefrontal = list(range(max(2, first_cut // 2)))
    posterior = list(range(second_cut, n_channels))
    left = list(range(0, n_channels, 2))
    right = list(range(1, n_channels, 2))
    return (frontal, central, parietal, prefrontal, posterior, left, right)
def get_region_indices(name_to_idx, n_channels: int):
    """Resolve the channel indices of each scalp region.

    Regions are looked up by channel name. If the name map is empty, the
    positional fallback is returned as-is. If any of the five main regions
    (frontal, central, parietal, prefrontal, posterior) found no channel,
    every empty region, including left/right, is filled from the fallback.

    Returns
    -------
    tuple
        ``(frontal, central, parietal, prefrontal, posterior, left, right)``.
    """
    if not name_to_idx:
        return _fallback_region_indices(n_channels)

    region_names = {
        "central": ["CZ","C1","C2","C3","C4","C5","C6","CP1","CP2","CP3","CP4","CP5","CP6","FC1","FC2","FC3","FC4","FC5","FC6"],
        "frontal": ["FZ","F1","F2","F3","F4","F5","F6","F7","F8","AF3","AF4","AF7","AF8","FPZ","FP1","FP2","FCZ"],
        "parietal": ["PZ","P1","P2","P3","P4","P5","P6","POZ","PO3","PO4","PO5","PO6","PO7","PO8","CPZ"],
        "prefrontal": ["FP1","FP2","FPZ","AF3","AF4","AF7","AF8"],
        "posterior": ["O1","O2","OZ","PO7","PO8","PO3","PO4","PZ","P3","P4","P1","P2"],
        "left": ["FP1","AF3","AF7","F3","F5","F7"],
        "right": ["FP2","AF4","AF8","F4","F6","F8"],
    }
    # Output order, which is also the order of the fallback tuple.
    order = ("frontal", "central", "parietal", "prefrontal", "posterior", "left", "right")
    picked = {key: pick_indices_by_names(name_to_idx, region_names[key]) for key in order}

    main_regions = order[:5]
    if not all(picked[key] for key in main_regions):
        defaults = dict(zip(order, _fallback_region_indices(n_channels)))
        for key in order:
            if not picked[key]:
                picked[key] = defaults[key]

    return tuple(picked[key] for key in order)
# ==========================
# PSD 和频段功率计算
# ==========================
def welch_psd(eeg_tc, fs):
    """Estimate the power spectral density along axis 0 with Welch's method.

    The segment length is PSD_NPERSEG, limited to the number of samples, and
    segments overlap by 75%.

    Returns
    -------
    tuple
        ``(freqs, pxx)`` as returned by ``scipy.signal.welch``.
    """
    seg_len = min(PSD_NPERSEG, eeg_tc.shape[0])
    overlap = int(seg_len * 0.75)
    freqs, pxx = signal.welch(
        eeg_tc,
        fs=fs,
        nperseg=seg_len,
        noverlap=overlap,
        axis=0,
        scaling="density",
    )
    return freqs, pxx
def band_power_from_psd(freqs, pxx_fc, band):
    """Integrate a PSD over a frequency band for every channel.

    Parameters
    ----------
    freqs : array
        Frequency bins of the PSD.
    pxx_fc : array
        PSD indexed as (frequency, channel).
    band : tuple
        ``(lo, hi)``; bins with ``lo <= f < hi`` are integrated.

    Returns
    -------
    numpy.ndarray
        One float32 power value per channel; all zeros when no bin falls
        inside the band.
    """
    lo, hi = band
    in_band = (freqs >= lo) & (freqs < hi)
    if not np.any(in_band):
        return np.zeros((pxx_fc.shape[1],), dtype=np.float32)
    # np.trapz is deprecated in NumPy 2.x in favour of np.trapezoid; use the
    # new name when it exists and fall back on older NumPy.
    integrate = getattr(np, "trapezoid", None) or np.trapz
    return integrate(pxx_fc[in_band, :], freqs[in_band], axis=0).astype(np.float32)
def region_mean_power(freqs, pxx_fc, idx, band) -> float:
    """Average the band power over the channels listed in ``idx``.

    Returns 0.0 when ``idx`` is empty.
    """
    if idx:
        per_channel = band_power_from_psd(freqs, pxx_fc, band)
        return float(np.mean(per_channel[idx]))
    return 0.0
def compute_iaf(freqs, pxx_fc, posterior_idx):
    """Return the frequency of the alpha-band peak of the posterior channels.

    The PSD is averaged over ``posterior_idx`` and the frequency of its
    maximum inside the Alpha band (both ends inclusive) is returned.
    Returns 0.0 when there are no posterior channels or no bins in the band.
    """
    alpha_lo, alpha_hi = BANDS_METRICS["Alpha"]
    in_alpha = (freqs >= alpha_lo) & (freqs <= alpha_hi)
    if not np.any(in_alpha) or not posterior_idx:
        return 0.0

    mean_spectrum = np.mean(pxx_fc[:, posterior_idx], axis=1)
    peak = int(np.argmax(mean_spectrum[in_alpha]))
    return float(freqs[in_alpha][peak])
# ==========================
# 画图函数
# ==========================
def plot_eeg_waveforms(data_uv_tc, fs, ch_names, out_dir, seconds=10, t_start_sec=30.0):
    """Plot a short stretch of the fixed display channels and save EEG.png.

    Parameters
    ----------
    data_uv_tc : array
        Signal indexed as (time, channel); plotted with a "uV" y-label.
    fs : float
        Sampling rate used to convert samples to seconds.
    ch_names : list
        Accepted for interface compatibility. Labels always come from
        FIXED_EEG_LABELS, exactly as before.
    out_dir : str
        Directory that receives ``EEG.png``.
    seconds : int
        Length of the plotted window.
    t_start_sec : float
        Requested window start. If it lies beyond the data, the last
        ``seconds`` of the recording are plotted instead.

    Raises
    ------
    RuntimeError
        If none of FIXED_EEG_IDXS is a valid channel index.
    """
    T, C = data_uv_tc.shape
    start_sample = int(t_start_sec * fs)
    end_sample = int(min(T, start_sample + seconds * fs))
    if start_sample >= T:
        start_sample = max(0, T - int(seconds * fs))
        end_sample = T
        print(f"[WARN] t_start_sec={t_start_sec}s exceeds data, using last {seconds}s")
    seg_samples = end_sample - start_sample
    # Derive the time axis from the sample actually used, so it stays correct
    # when the window was moved to the end of the recording.
    x = (start_sample + np.arange(seg_samples)) / fs

    # Keep only the display channels that exist in this recording.
    idxs = [i for i in FIXED_EEG_IDXS if 0 <= i < C]
    if len(idxs) < len(FIXED_EEG_IDXS):
        missing = [i for i in FIXED_EEG_IDXS if not (0 <= i < C)]
        print(f"[WARN] Some indices out of range (C={C}): {missing}")
    if len(idxs) == 0:
        raise RuntimeError(f"No valid indices for data (C={C})")

    picked_names = []
    for idx in idxs:
        pos = FIXED_EEG_IDXS.index(idx)
        picked_names.append(FIXED_EEG_LABELS[pos] if pos < len(FIXED_EEG_LABELS) else f"CH{idx}")

    fig_h = 1.4 * len(idxs) + 1
    fig, axes = plt.subplots(len(idxs), 1, figsize=(10, fig_h), sharex=True)
    if len(idxs) == 1:
        axes = [axes]

    # One symmetric y-range for all panels: the 1st/99th percentile of the
    # plotted data, but never tighter than ±50.
    seg = data_uv_tc[start_sample:end_sample, idxs].T
    lo = float(np.percentile(seg, 1))
    hi = float(np.percentile(seg, 99))
    m = max(abs(lo), abs(hi), 50.0)

    for ax, ch_idx, nm in zip(axes, idxs, picked_names):
        y = data_uv_tc[start_sample:end_sample, ch_idx]
        ax.plot(x, y, linewidth=1.2)
        ax.set_ylabel("uV")
        ax.set_title(nm, loc="left", fontsize=10)
        ax.grid(True, alpha=0.3)
        ax.set_ylim(-m, m)
    axes[-1].set_xlabel("Time (s)")
    plt.tight_layout()
    out_path = os.path.join(out_dir, "EEG.png")
    plt.savefig(out_path, dpi=200)
    plt.close(fig)
    print(f"[OK] EEG waveform saved: {out_path}")
def plot_psd(eeg_uV_tc, fs, ch_names, out_dir):
    """Plot the PSD of three channels in dB and save psd.png.

    With channel names available, C3/C4/CZ are preferred and any missing
    slots are filled with the highest-variance channels. Without names, the
    three highest-variance channels are used and labelled ``CH<index>``.
    """
    n_ch = eeg_uV_tc.shape[1]

    def _ranked_by_std():
        # Channel indices ordered from largest to smallest standard deviation.
        scored = [(i, float(np.std(eeg_uV_tc[:, i]))) for i in range(n_ch)]
        scored.sort(key=lambda pair: pair[1], reverse=True)
        return [i for i, _ in scored]

    if ch_names:
        upper_to_idx = {n.upper(): i for i, n in enumerate(ch_names)}
        chosen_idx = [upper_to_idx[p] for p in ("C3", "C4", "CZ") if p in upper_to_idx]
        if len(chosen_idx) < 3:
            for i in _ranked_by_std():
                if i not in chosen_idx:
                    chosen_idx.append(i)
                if len(chosen_idx) == 3:
                    break
        chosen_name = [ch_names[i] for i in chosen_idx]
    else:
        chosen_idx = _ranked_by_std()[:3]
        chosen_name = [f"CH{i}" for i in chosen_idx]

    # Same Welch settings as the metric computation: 75% overlap.
    seg_len = min(PSD_NPERSEG, eeg_uV_tc.shape[0])
    overlap = int(seg_len * 0.75)

    fig = plt.figure(figsize=(7.5, 4.8))
    for ch, label in zip(chosen_idx, chosen_name):
        f, pxx = signal.welch(eeg_uV_tc[:, ch], fs=fs, nperseg=seg_len, noverlap=overlap)
        shown = (f >= PSD_FMIN) & (f <= PSD_FMAX)
        power_db = 10 * np.log10(pxx[shown] + 1e-20)
        plt.plot(f[shown], power_db, linewidth=1.8, label=label)
    plt.xlabel("Hz")
    plt.ylabel("Power (dB)")
    plt.title("PSD")
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.tight_layout()

    out_path = os.path.join(out_dir, "psd.png")
    plt.savefig(out_path, dpi=200)
    plt.close(fig)
    print(f"[OK] PSD saved: {out_path}")
def _get_standard_1020_channel_indices(raw):
    """Find the channels of ``raw`` that exist in the standard_1020 montage.

    Matching is case-insensitive, and the reference/mastoid names A1, A2,
    M1, M2, LE, RE, LM, RM are excluded.

    Returns
    -------
    tuple
        ``(indices, names)`` where names use the montage's own spelling, or
        ``(None, None)`` if anything fails.
    """
    try:
        montage = mne.channels.make_standard_montage("standard_1020")
        canonical = {ch.upper(): ch for ch in montage.ch_names}
        skipped = {"A1", "A2", "M1", "M2", "LE", "RE", "LM", "RM"}
        recording_names = raw.ch_names

        matches = []
        for position, label in enumerate(recording_names):
            key = label.upper()
            if key in canonical and key not in skipped:
                matches.append((position, canonical[key]))

        valid_indices = [position for position, _ in matches]
        valid_names = [spelling for _, spelling in matches]
        print(f"[INFO] Found {len(valid_indices)}/{len(recording_names)} channels matching standard_1020")
        return valid_indices, valid_names
    except Exception as err:
        print(f"[WARN] Failed to get standard_1020 channels: {err}")
        return None, None
def _standard_1020_positions_upper():
    """Return standard_1020 electrode positions keyed by upper-cased name."""
    montage = mne.channels.make_standard_montage("standard_1020")
    ch_pos = montage.get_positions()['ch_pos']
    return {name.upper(): pos for name, pos in ch_pos.items()}


def compute_band_powers_for_topomap(raw, bands):
    """Compute per-channel band power for channels with a standard position.

    Channel names are matched against standard_1020 case-insensitively. The
    montage spells names in mixed case (e.g. "Fp1", "Cz"), so the position
    lookup uses an upper-cased key map; a lookup of the upper-cased name in
    the original mixed-case dict would silently drop those channels.

    Parameters
    ----------
    raw : mne Raw object
        Source of the data, channel names and sampling rate.
    bands : dict
        ``{band_name: (fmin, fmax)}``; bins with ``fmin <= f < fmax`` are used.

    Returns
    -------
    dict or None
        ``{"_valid_names": [...], band_name: power_array, ...}`` with power
        scaled by 1e12 (V^2 -> uV^2), or None when fewer than 8 channels
        have a position.
    """
    pos_by_upper = _standard_1020_positions_upper()
    exclude = {"A1", "A2", "M1", "M2", "LE", "RE", "LM", "RM"}

    # Keep only channels that have a standard position.
    valid_indices = []
    valid_names = []
    for i, name in enumerate(raw.ch_names):
        key = name.upper()
        if key in pos_by_upper and key not in exclude:
            valid_indices.append(i)
            valid_names.append(name)
    if len(valid_indices) < 8:
        return None

    data_standard = raw.get_data()[valid_indices, :]
    fs = raw.info["sfreq"]
    n_fft = min(PSD_NPERSEG, data_standard.shape[1])
    n_overlap = int(n_fft * 0.75)
    psds, freqs = mne.time_frequency.psd_array_welch(
        data_standard, sfreq=fs,
        fmin=min(v[0] for v in bands.values()),
        fmax=max(v[1] for v in bands.values()),
        n_fft=n_fft, n_overlap=n_overlap,
        average="mean", verbose=False
    )

    # np.trapz is deprecated in NumPy 2.x in favour of np.trapezoid.
    integrate = getattr(np, "trapezoid", None) or np.trapz

    out = {"_valid_names": valid_names}
    print(f"[DEBUG] PSD: fs={fs}Hz, n_fft={n_fft}, freq_res={fs/n_fft:.3f}Hz/bin")
    for k, (fmin, fmax) in bands.items():
        idx = np.where((freqs >= fmin) & (freqs < fmax))[0]
        if len(idx) == 0:
            out[k] = np.zeros(len(valid_indices), dtype=np.float32)
            print(f"[DEBUG] {k.upper()}: NO freq bins in [{fmin}-{fmax}]Hz")
            continue
        print(f"[DEBUG] {k.upper()}: freq bins {freqs[idx[0]]:.2f}-{freqs[idx[-1]]:.2f}Hz (bins {idx[0]}-{idx[-1]}, count={len(idx)})")
        # Linear power, converted from V^2 to uV^2.
        bp = integrate(psds[:, idx], freqs[idx], axis=1) * 1e12
        out[k] = bp
        print(f"[DEBUG] {k.upper()}: power range [{bp.min():.4f}, {bp.max():.4f}] uV^2, mean={bp.mean():.4f}")
    print(f"[INFO] Band powers computed for {len(valid_names)} channels with positions")
    return out


def _create_topomap_raw(ch_names):
    """Build a dummy Raw that only carries standard_1020 channel positions.

    Names are matched case-insensitively against the montage. Returns None
    when fewer than 8 of the given names have a position. The Raw holds a
    single zero sample per channel; only its ``info`` is meant to be used.
    """
    pos_by_upper = _standard_1020_positions_upper()
    valid_ch_names = []
    valid_positions = []
    for name in ch_names:
        key = name.upper()
        if key in pos_by_upper:
            valid_ch_names.append(name)
            valid_positions.append(pos_by_upper[key])
    if len(valid_ch_names) < 8:
        return None

    ch_pos = dict(zip(valid_ch_names, valid_positions))
    montage = mne.channels.make_dig_montage(ch_pos=ch_pos, coord_frame='head')
    # The sampling rate is a placeholder; the Raw holds no real signal.
    info = mne.create_info(ch_names=valid_ch_names, sfreq=250.0, ch_types=["eeg"] * len(valid_ch_names))
    info.set_montage(montage)
    dummy_data = np.zeros((len(valid_ch_names), 1))
    return mne.io.RawArray(dummy_data, info, verbose=False)
def plot_average_topomap(band_values, out_dir):
    """Draw the broadband topomap and save average_topomap.png.

    Parameters
    ----------
    band_values : dict
        Output of ``compute_band_powers_for_topomap``; the "broad" entry and
        "_valid_names" are used.
    out_dir : str
        Directory that receives the image.

    Nothing is drawn when there are no valid names or the position-only
    Raw cannot be built.
    """
    valid_names = band_values.get("_valid_names", [])
    if not valid_names:
        return
    values = band_values["broad"]
    temp_raw = _create_topomap_raw(valid_names)
    if temp_raw is None:
        return

    vmin, vmax = _compute_topomap_vlim([values])
    fig, ax = plt.subplots(1, 1, figsize=(6.5, 4.6))
    im, _ = mne.viz.plot_topomap(
        values, temp_raw.info, axes=ax, show=False, contours=0,
        sphere=(0, 0, 0, TOPOMAP_SPHERE_RADIUS), extrapolate='head', border='mean',
        cmap='turbo'
    )
    im.set_clim(vmin=vmin, vmax=vmax)
    # Label the plot with the band that is actually plotted; the previous
    # hard-coded "0.8-30 Hz" did not match the "broad" band definition.
    broad_lo, broad_hi = BANDS_TOPOMAP["broad"]
    ax.set_title(f"{broad_lo:g}-{broad_hi:g} Hz", fontsize=12)
    plt.colorbar(im, ax=ax, shrink=0.85)
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "average_topomap.png"), dpi=200)
    plt.close(fig)
    print("[OK] average_topomap saved")
def plot_band_topomaps(band_values, out_dir):
    """Draw one topomap per band side by side and save topomaps.png.

    All five panels share one colour range, and a single colour bar is drawn
    from the last panel. Nothing is drawn when there are no valid names or
    the position-only Raw cannot be built.
    """
    names = band_values.get("_valid_names", [])
    if not names:
        return

    panels = [
        ("delta", "δ (1-4Hz)"),
        ("theta", "θ (4-8Hz)"),
        ("alpha", "α (8-13Hz)"),
        ("beta", "β (13-30Hz)"),
        ("broad", "1-30 Hz"),
    ]
    layout_raw = _create_topomap_raw(names)
    if layout_raw is None:
        return

    # Shared colour limits computed over every band.
    lower, upper = _compute_topomap_vlim([band_values[key] for key, _ in panels])

    fig, axes = plt.subplots(1, 5, figsize=(16, 4.2))
    images = []
    for axis, (key, heading) in zip(axes, panels):
        image, _ = mne.viz.plot_topomap(
            band_values[key], layout_raw.info, axes=axis, show=False, contours=0,
            sphere=(0, 0, 0, TOPOMAP_SPHERE_RADIUS), extrapolate='head', border='mean',
            cmap='turbo'
        )
        image.set_clim(vmin=lower, vmax=upper)
        axis.set_title(heading, fontsize=11)
        images.append(image)

    # Leave room on the right for a dedicated colour-bar axis.
    fig.subplots_adjust(left=0.02, right=0.85, top=0.88, bottom=0.05, wspace=0.35)
    bar_axis = fig.add_axes([0.87, 0.15, 0.015, 0.7])
    fig.colorbar(images[-1], cax=bar_axis)
    plt.savefig(os.path.join(out_dir, "topomaps.png"), dpi=200)
    plt.close(fig)
    print("[OK] topomaps saved")
def _compute_topomap_vlim(values):
    """Choose the colour limits for the topomaps.

    Priority:
    1. a fixed range, when both TOPOMAP_VMIN and TOPOMAP_VMAX are set;
    2. mean ± std * TOPOMAP_SYM_SCALE, when that scale is positive;
    3. ``(0, max)`` over all values.

    Parameters
    ----------
    values : list of arrays, or array
        Band-power values; a list is concatenated first.

    Returns
    -------
    tuple
        ``(vmin, vmax)``.
    """
    v_all = np.concatenate(values) if isinstance(values, list) else np.array(values)

    if TOPOMAP_VMIN is not None and TOPOMAP_VMAX is not None:
        # Honour the configured range. The previous code returned
        # (TOPOMAP_VMAX - 60, TOPOMAP_VMAX) and ignored TOPOMAP_VMIN.
        return TOPOMAP_VMIN, TOPOMAP_VMAX

    if TOPOMAP_SYM_SCALE is not None and TOPOMAP_SYM_SCALE > 0:
        mean_val = np.mean(v_all)
        half_width = np.std(v_all) * TOPOMAP_SYM_SCALE
        return mean_val - half_width, mean_val + half_width

    # Shared maximum across all bands with a zero floor, so that low-power
    # bands stay near the bottom of the colour scale.
    return 0, np.max(v_all)
# ==========================
# 预测接口
# ==========================
def _predict_label_by_model(model_path: str, data_path: str) -> dict:
    """Run ``infer_pth.predict_hc_mdd`` on a .mat or .bdf recording.

    ``predict_hc_mdd`` takes a directory, so a .mat file is passed via its
    parent directory. A .bdf file is re-read from disk as stored (this
    function does not preprocess it), written to a temporary .mat file in
    microvolts, and that temporary directory is passed instead.

    Parameters
    ----------
    model_path : str
        Path to the .pth model file.
    data_path : str
        Path to the .mat or .bdf recording.

    Returns
    -------
    dict
        The dict returned by ``predict_hc_mdd``.

    Raises
    ------
    RuntimeError
        If ``predict_hc_mdd`` cannot be imported (original error chained).
    ValueError
        If the file extension is not supported.
    """
    try:
        from infer_pth import predict_hc_mdd
    except Exception as e:
        # Chain the cause so the real import failure stays in the traceback.
        raise RuntimeError(f"无法导入 predict_hc_mdd: {e}") from e
    import tempfile
    import scipy.io

    ext = os.path.splitext(data_path)[1].lower()
    if ext == ".mat":
        # abspath first: dirname of a bare relative filename would be "".
        return predict_hc_mdd(os.path.dirname(os.path.abspath(data_path)), model_path)

    if ext == ".bdf":
        raw = mne.io.read_raw_bdf(data_path, preload=True, verbose=False)
        data, _times = raw[:]
        sfreq = raw.info['sfreq']
        ch_names = raw.ch_names
        with tempfile.TemporaryDirectory() as temp_dir:
            mat_path = os.path.join(temp_dir, "preprocessed_eeg.mat")
            scipy.io.savemat(mat_path, {
                'eeg': {
                    'data': (data * 1e6).T,  # V -> uV, stored as (T, C)
                    'sample_rate': sfreq,
                    'electrode_name': ch_names
                }
            })
            # Predict inside the with-block, while the temp file still exists.
            return predict_hc_mdd(temp_dir, model_path)

    raise ValueError(f"不支持的文件格式: {ext}")
# ==========================
# 生成 ResultData.txt
# ==========================
def compute_and_save_txt(model_path, bdf_path, out_dir, eeg_uV_tc, fs, ch_names):
    """Compute the spectral metrics, run the classifier, write ResultData.txt.

    Parameters
    ----------
    model_path : str
        Path to the .pth model, forwarded to the predictor.
    bdf_path : str
        Path to the recording, forwarded to the predictor.
    out_dir : str
        Directory that receives ``ResultData.txt``.
    eeg_uV_tc : array
        Preprocessed signal indexed as (time, channel).
    fs : float
        Sampling rate.
    ch_names : list
        Channel names used to locate the scalp regions.
    """
    # Prediction: the treatment recommendation follows the predicted label.
    pred_result = _predict_label_by_model(model_path, bdf_path)
    pred_label = pred_result.get("pred_label", "UNKNOWN")
    # Both branches used to be "", which left the recommendation line blank
    # whatever the prediction was.
    recommend = "是" if pred_label == "MDD" else "否"

    _, C = eeg_uV_tc.shape
    mp = build_channel_index_map(ch_names, C)
    frontal_idx, central_idx, parietal_idx, prefrontal_idx, posterior_idx, left_idx, right_idx = \
        get_region_indices(mp, C)
    freqs, pxx = welch_psd(eeg_uV_tc, fs)

    def _ratio(numerator, denominator):
        # Guarded division: 0.0 when the denominator carries no power.
        return (numerator / (denominator + EPS)) if denominator > 0 else 0.0

    def _power(idx, band):
        return region_mean_power(freqs, pxx, idx, band)

    # Alpha/beta and theta/beta ratios per region.
    central_beta = _power(central_idx, BANDS_METRICS["Beta"])
    frontal_beta = _power(frontal_idx, BANDS_METRICS["Beta"])
    par_beta = _power(parietal_idx, BANDS_METRICS["Beta"])
    central_ab = _ratio(_power(central_idx, BANDS_METRICS["Alpha"]), central_beta)
    frontal_ab = _ratio(_power(frontal_idx, BANDS_METRICS["Alpha"]), frontal_beta)
    par_ab = _ratio(_power(parietal_idx, BANDS_METRICS["Alpha"]), par_beta)
    central_tb = _ratio(_power(central_idx, BANDS_METRICS["Theta"]), central_beta)
    par_tb = _ratio(_power(parietal_idx, BANDS_METRICS["Theta"]), par_beta)

    # Prefrontal alpha asymmetry: ln(right) - ln(left). If the named
    # left/right groups are missing, split the prefrontal indices by parity.
    if not left_idx or not right_idx:
        left_idx = [i for i in prefrontal_idx if (i % 2 == 0)]
        right_idx = [i for i in prefrontal_idx if (i % 2 == 1)]
    left_alpha = _power(left_idx, BANDS_METRICS["Alpha"])
    right_alpha = _power(right_idx, BANDS_METRICS["Alpha"])
    prefrontal_alpha_asym = float(np.log(right_alpha + EPS) - np.log(left_alpha + EPS))

    iaf = compute_iaf(freqs, pxx, posterior_idx)

    # Prefrontal delta+theta power as a percentage of the total-band power.
    pre_td = _power(prefrontal_idx, (BANDS_METRICS["Delta"][0], BANDS_METRICS["Theta"][1]))
    pre_total = _power(prefrontal_idx, TOTAL_POWER_BAND)
    pre_td_rel = _ratio(pre_td, pre_total) * 100.0

    def f1(x): return f"{x:.1f}"
    txt = (
        f"中央区α/β波比值:{f1(central_ab)}\n"
        f"额区α/β波比值:{f1(frontal_ab)}\n"
        f"顶区α/β波比值:{f1(par_ab)}\n"
        f"中央区θ/β波比值:{f1(central_tb)}\n"
        f"顶区θ/β波比值:{f1(par_tb)}\n"
        f"前额叶α波不对称性:{f1(prefrontal_alpha_asym)}\n"
        f"个体化α峰值频率:{f1(iaf)}\n"
        f"前额叶θ+δ波功率:{f1(pre_td_rel)}\n"
        f"是否推荐治疗:{recommend}\n"
    )
    out_path = os.path.join(out_dir, "ResultData.txt")
    with open(out_path, "w", encoding="utf-8") as f:
        f.write(txt)
    print(f"[OK] ResultData.txt saved: {out_path}")

    # Console summary. Formatting the 'N/A' default with ':.4f' would raise,
    # so a non-numeric value is printed as-is.
    p_mdd = pred_result.get('p_mdd_mean', 'N/A')
    try:
        p_mdd_text = f"{p_mdd:.4f}"
    except (TypeError, ValueError):
        p_mdd_text = str(p_mdd)
    print("\n========== 预测结果 ==========")
    print(f"预测标签: {pred_label}")
    print(f"p(MDD)均值: {p_mdd_text}")
    print(f"切片数量: {pred_result.get('n_slices', 'N/A')}")
    print("==============================\n")
# ==========================
# 主函数
# ==========================
def run_all(model_path: str, bdf_dir: str, out_root: str, seconds: int = EEG_PLOT_SECONDS):
    """Process the first recording of a folder end to end.

    The first .bdf/.mat file in name order is loaded and preprocessed, the
    PSD, waveform and topomap figures are written, then the prediction is
    run and ResultData.txt is saved. A topomap failure is only a warning.

    Parameters
    ----------
    model_path : str
        Path to the .pth model file.
    bdf_dir : str
        Folder containing the input recording.
    out_root : str
        Output folder; previous contents other than ResultData.txt are removed.
    seconds : int
        Length of the plotted waveform window.

    Returns
    -------
    str
        The output directory.

    Raises
    ------
    RuntimeError
        If the input folder is missing or holds no .bdf/.mat file.
    """
    if not os.path.exists(bdf_dir):
        raise RuntimeError(f"输入目录不存在: {bdf_dir}")

    # Both .bdf and .mat inputs are accepted; take the first in name order.
    candidates = sorted(
        entry for entry in os.listdir(bdf_dir)
        if entry.lower().endswith((".bdf", ".mat"))
    )
    if not candidates:
        raise RuntimeError(f"目录中找不到 .bdf 或 .mat 文件: {bdf_dir}")
    data_path = os.path.join(bdf_dir, candidates[0])
    print(f"[INFO] Processing file: {data_path}")

    out_dir = ensure_outdir(out_root)
    print(f"[INFO] Output directory: {out_dir}")

    raw, sfreq, ch_names = load_data_file(data_path)
    cleaned = preprocess_bdf(raw)
    try:
        cleaned.set_montage("standard_1020", on_missing="ignore")
    except Exception as err:
        print(f"[WARN] Failed to re-apply montage: {err}")

    # Volts -> microvolts, arranged as (time, channel) for the plot helpers.
    eeg_uV_tc = (cleaned.get_data() * 1e6).T.astype(np.float32)
    print(f"[INFO] Preprocessed EEG shape: {eeg_uV_tc.shape}")

    print("[INFO] Generating figures...")
    plot_psd(eeg_uV_tc, sfreq, ch_names, out_dir)
    plot_eeg_waveforms(eeg_uV_tc, sfreq, ch_names, out_dir, seconds=seconds)

    print("[INFO] Generating topomaps...")
    try:
        band_vals = compute_band_powers_for_topomap(cleaned, BANDS_TOPOMAP)
        if band_vals is not None:
            plot_average_topomap(band_vals, out_dir)
            plot_band_topomaps(band_vals, out_dir)
    except Exception as err:
        print(f"[WARN] Topomap generation failed: {err}")

    print("[INFO] Running prediction...")
    compute_and_save_txt(model_path, data_path, out_dir, eeg_uV_tc, sfreq, ch_names)
    print("[DONE] All tasks completed.")
    return out_dir
# ==========================
# 命令行入口
# ==========================
if __name__ == "__main__":
    import multiprocessing
    multiprocessing.freeze_support()
    import argparse
    import sys

    def get_resource_path(relative_path):
        """Resolve a path relative to the executable (frozen) or this script."""
        if getattr(sys, 'frozen', False):
            base_path = os.path.dirname(sys.executable)
        else:
            base_path = os.path.dirname(os.path.abspath(__file__))
        return os.path.join(base_path, relative_path)

    # Default locations, all relative to the executable / script directory.
    DEFAULT_MODEL = get_resource_path(os.path.join("model", "Model_1.pth"))
    if getattr(sys, 'frozen', False):
        EXE_DIR = os.path.dirname(sys.executable)
    else:
        EXE_DIR = os.path.dirname(os.path.abspath(__file__))
    DEFAULT_BDF_DIR = os.path.join(EXE_DIR, "raw_data")
    DEFAULT_OUT = os.path.join(EXE_DIR, "out")

    parser = argparse.ArgumentParser(description="EEG Depression Assessment")
    parser.add_argument("--model_path", type=str, default=DEFAULT_MODEL, help="模型文件路径 (.pth)")
    parser.add_argument("--bdf_dir", type=str, default=DEFAULT_BDF_DIR, help="输入文件夹路径 (包含 .bdf 或 .mat 文件)")
    parser.add_argument("--out_root", type=str, default=DEFAULT_OUT, help="结果输出目录")
    parser.add_argument("--seconds", type=int, default=10, help="画波形图的秒数")
    args = parser.parse_args()

    print(f"[*] 运行配置:")
    print(f" - Model : {args.model_path}")
    print(f" - Input : {args.bdf_dir}")
    print(f" - Output: {args.out_root}")

    # Stop before any work is done (and before the output folder is emptied)
    # when a required path is missing, instead of printing the error and then
    # running into the same failure later.
    inputs_ok = True
    if not os.path.exists(args.bdf_dir):
        print(f"[ERROR] 输入目录不存在: {args.bdf_dir}")
        inputs_ok = False
    if not os.path.exists(args.model_path):
        print(f"[ERROR] 模型文件不存在: {args.model_path}")
        inputs_ok = False
    if not inputs_ok:
        sys.exit(1)

    run_all(args.model_path, args.bdf_dir, args.out_root, seconds=args.seconds)