# Files
# Depression_TMS/algorithm_V0/algorithm_fromXjtu/runDecoder.py
# 2026-06-01 13:18:36 +08:00
#
# 910 lines
# 32 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
from __future__ import annotations
"""
run_metrics_and_figs.py
1) 自动读取 mat_dir 中排序后的第一个 .mat
2) 调用模型预测HC/MDD并写 ResultData.txt
3) 同时保存图片EEG.png / psd.png / average_topomap.png / topomaps.png
"""
import matplotlib
matplotlib.use('Agg')
import numpy as np
import os
import shutil
import scipy.io
import scipy.signal as signal
import matplotlib.pyplot as plt
import mne
from mne.preprocessing import ICA
# ==========================
# Config
# ==========================
# Band-pass edges (Hz) passed to raw.filter in run_preprocess_on_raw.
PREPROCESS_BANDPASS = (0.8, 30.0)
# Notch frequencies (Hz) passed to raw.notch_filter; presumably power-line noise
# at 50 Hz plus its 2nd harmonic.
PREPROCESS_NOTCH = [50, 100]
# n_components handed to mne ICA; a float < 1 is interpreted by MNE as a
# fraction of explained variance (see MNE ICA docs).
PREPROCESS_ICA_N = 0.99
# random_state for ICA so the decomposition is reproducible between runs.
PREPROCESS_ICA_SEED = 97
# When True, re-reference to the common average before bad-segment detection and ICA.
PREPROCESS_APPLY_AVG_REF = True
PREPROCESS_BAD_PTP_UV = 350.0  # peak-to-peak threshold (μV) for marking a 1 s window as BAD_PTP
# Sampling rate (Hz) used when the .mat file carries no usable sample_rate.
DEFAULT_FS = 250.0
# Default number of seconds drawn in EEG.png.
EEG_PLOT_SECONDS = 10
# Frequency range (Hz) shown in psd.png.
PSD_FMIN, PSD_FMAX = 0.8, 45.0
# Small constant guarding divisions and logarithms against zero.
EPS = 1e-12
FIXED_EEG_IDXS = [23, 47, 39, 6, 2, 21, 35, 57]  # 0-based channel indices drawn in EEG.png, ordered by importance
# Display labels paired position-by-position with FIXED_EEG_IDXS.
# NOTE(review): these labels assume one specific 64-channel ordering — confirm against the montage.
FIXED_EEG_LABELS = ["C5", "O1", "TP7", "FPZ", "PO6", "P4", "AF7", "AF3"]
# Bands (Hz) for the ResultData.txt metrics; band_power_from_psd masks them as [lo, hi).
BANDS_METRICS = {
    "Delta": (1.0, 4.0),
    "Theta": (4.0, 8.0),
    "Alpha": (8.0, 13.0),
    "Beta": (13.0, 30.0),
}
# Denominator band (Hz) for the relative prefrontal theta+delta power.
TOTAL_POWER_BAND = (1.0, 50.0)
# Bands (Hz) for the topomaps; compute_band_powers_for_topomap masks them as [lo, hi].
BANDS_TOPOMAP = {
    "delta": (0.8, 3.9),
    "theta": (4.0, 7.9),
    "alpha": (8.0, 12.9),
    "beta": (13.0, 30.0),
    "broad": (0.8, 30.0),
}
# ==========================
# 预处理逻辑
# ==========================
def annotate_bad_segments(raw, peak_to_peak_uv=250.0):
"""
简单坏段检测:按固定窗口计算峰峰值,超过阈值标为 bad。
"""
peak_to_peak_v = peak_to_peak_uv * 1e-6
win = int(raw.info["sfreq"] * 1.0)
step = int(raw.info["sfreq"] * 0.5)
data = raw.get_data()
n_times = data.shape[1]
onsets = []
durations = []
descriptions = []
for start in range(0, n_times - win, step):
seg = data[:, start:start + win]
ptp = np.ptp(seg, axis=1)
if np.any(ptp > peak_to_peak_v):
onsets.append(start / raw.info["sfreq"])
durations.append(win / raw.info["sfreq"])
descriptions.append("BAD_PTP")
if len(onsets) > 0:
ann = mne.Annotations(onset=onsets, duration=durations, description=descriptions)
raw.set_annotations(ann)
print(f"[INFO] Annotated bad segments: {len(onsets)} windows")
else:
print("[INFO] No bad segments detected by PTP rule")
def run_preprocess_on_raw(raw: mne.io.RawArray) -> mne.io.RawArray:
"""
核心预处理:滤波 + 平均参考 + 坏段标注 + ICA
"""
# 1) 滤波
raw.filter(PREPROCESS_BANDPASS[0], PREPROCESS_BANDPASS[1], fir_design="firwin", verbose=False)
raw.notch_filter(PREPROCESS_NOTCH, fir_design="firwin", verbose=False)
# 2) 平均参考
if PREPROCESS_APPLY_AVG_REF:
raw.set_eeg_reference("average", verbose=False)
# 3) 坏段标注
annotate_bad_segments(raw, peak_to_peak_uv=PREPROCESS_BAD_PTP_UV)
# 4) ICA
ica = ICA(
n_components=PREPROCESS_ICA_N,
random_state=PREPROCESS_ICA_SEED,
max_iter=800,
method="fastica"
)
ica.fit(raw, reject_by_annotation=True, verbose=False)
try:
eog_inds, _ = ica.find_bads_eog(raw, verbose=False)
if eog_inds:
ica.exclude.extend(eog_inds)
print(f"[INFO] ICA exclude EOG comps: {eog_inds}")
except Exception as e:
print(f"[WARN] ICA find_bads_eog skipped: {e}")
raw_clean = ica.apply(raw.copy(), verbose=False)
return raw_clean
def preprocess_mat_file(src_mat_path: str, temp_out_dir: str) -> str:
    """
    Load a raw .mat recording, preprocess it and save the cleaned copy.

    Steps: load_eeg_from_mat -> MNE RawArray (μV converted to V) ->
    run_preprocess_on_raw -> write ``temp_out_dir/<same file name>``.

    If the source has an ``eeg`` variable, that struct is re-saved with only its
    ``data`` field replaced by the cleaned (T, C) μV matrix, so that
    load_eeg_from_mat reads the output the same way. If there is no ``eeg``
    variable or re-saving it fails, a minimal ``eeg`` struct with data /
    sample_rate / electrode_name / electrode_xyz is written instead.

    Returns the path of the written file.
    """
    os.makedirs(temp_out_dir, exist_ok=True)
    # 1. Read the original recording as (T, C) μV.
    eeg_uV, fs, ch_names, xyz = load_eeg_from_mat(src_mat_path)
    n_channels = eeg_uV.shape[1]
    # mne.create_info needs exactly one name per data channel.
    if not ch_names or len(ch_names) != n_channels:
        if ch_names:
            print(f"[WARN] electrode_name count ({len(ch_names)}) != channel count ({n_channels}); using generic names")
        ch_names = [f"CH{i+1}" for i in range(n_channels)]
    info = mne.create_info(ch_names=ch_names, sfreq=fs, ch_types=["eeg"] * n_channels)
    raw = mne.io.RawArray(eeg_uV.T * 1e-6, info, verbose=False)
    if isinstance(xyz, np.ndarray):
        # Positions are not needed for filtering/ICA, but keep them when available.
        try:
            ch_pos = {name: xyz[i, :] for i, name in enumerate(ch_names)}
            montage = mne.channels.make_dig_montage(ch_pos=ch_pos, coord_frame="head")
            raw.set_montage(montage, on_missing="ignore")
        except Exception as e:
            print(f"[WARN] set_montage skipped during preprocessing: {e}")
    # 2. Preprocess.
    print(f"[INFO] Start preprocessing: {src_mat_path}")
    raw_clean = run_preprocess_on_raw(raw)
    # MNE works in volts; store μV with shape (T, C).
    clean_data_uV = (raw_clean.get_data() * 1e6).T
    new_path = os.path.join(temp_out_dir, os.path.basename(src_mat_path))
    # 3. Save, preserving the original "eeg" struct when there is one.
    try:
        mat_struct = scipy.io.loadmat(src_mat_path, struct_as_record=False, squeeze_me=True)
        if "eeg" in mat_struct:
            eeg_obj = mat_struct["eeg"]
            eeg_obj.data = clean_data_uV
            scipy.io.savemat(new_path, {"eeg": eeg_obj}, do_compression=True)
            print(f"[INFO] Preprocessed file saved to: {new_path}")
            return new_path
        print("[WARN] No 'eeg' variable in source mat; saving fallback struct")
    except Exception as e:
        print(f"[WARN] Failed to preserve original struct structure: {e}")
    # Fallback: the original structure is unavailable, so write a simple struct.
    out_dict = {
        "eeg": {
            "data": clean_data_uV,
            "sample_rate": fs,
            "electrode_name": ch_names,
            "electrode_xyz": xyz if xyz is not None else [],
        }
    }
    scipy.io.savemat(new_path, out_dict, do_compression=True)
    print(f"[INFO] Preprocessed file saved (fallback mode) to: {new_path}")
    return new_path
# ==========================
# 输出目录
# ==========================
def ensure_outdir(out_root: str) -> str:
    """
    Prepare ``out_root`` for a fresh run and return it unchanged.

    A missing directory is created. An existing directory is emptied: every file,
    symlink and sub-directory is deleted except ``ResultData.txt``. Deletion
    failures are printed as warnings rather than raised. No timestamped
    sub-folder is created; outputs go straight into ``out_root``.
    """
    if not os.path.exists(out_root):
        os.makedirs(out_root, exist_ok=True)
        return out_root
    preserved = "ResultData.txt"
    for entry in os.listdir(out_root):
        if entry == preserved:
            continue
        target = os.path.join(out_root, entry)
        try:
            is_plain = os.path.isfile(target) or os.path.islink(target)
            if is_plain:
                os.unlink(target)
            elif os.path.isdir(target):
                shutil.rmtree(target)
        except Exception as e:
            print(f"[WARN] Failed to delete {target}. Reason: {e}")
    return out_root
# ==========================
# 单位自动识别:统一到 μV
# ==========================
def _auto_scale_to_uV(data_nt_nc: np.ndarray):
data = np.asarray(data_nt_nc)
p95 = float(np.percentile(np.abs(data), 95))
if p95 <= 0.5:
data_uV = data * 1e6
msg = f"[UNIT] p95={p95:.3g} -> assume V, convert to μV by *1e6"
elif p95 > 5000:
data_uV = data * 1e-3
msg = f"[UNIT] p95={p95:.3g} -> assume nV, convert to μV by /1000"
else:
data_uV = data
msg = f"[UNIT] p95={p95:.3g} -> assume μV, no scaling"
p95_uV = float(np.percentile(np.abs(data_uV), 95))
warn = None
if p95_uV > 5000:
warn = f"[WARN] After scaling, p95 still large: {p95_uV:.3g} μV"
elif p95_uV < 0.1:
warn = f"[WARN] After scaling, p95 still small: {p95_uV:.3g} μV"
return data_uV, msg, warn
# ==========================
# mat 读取(支持 struct.data / electrode_name / electrode_xyz / sample_rate
# ==========================
def _unwrap_singleton(x):
while True:
if isinstance(x, np.ndarray):
if x.dtype == object and x.size == 1:
x = x.item()
continue
if x.size == 1 and x.ndim >= 1:
try:
x = x.reshape(-1)[0]
continue
except Exception:
pass
break
return x
def _try_get_struct_field(v, field_name="data"):
if hasattr(v, "_fieldnames") and field_name in getattr(v, "_fieldnames", []):
return getattr(v, field_name)
if isinstance(v, np.ndarray) and v.dtype.names and field_name in v.dtype.names:
try:
return v[field_name]
except Exception:
return None
return None
def _extract_electrode_names(st):
    """
    Read the struct's ``electrode_name`` field as a list of stripped strings.

    Accepts a list/tuple, a numpy array (flattened, each element unwrapped), or a
    single scalar/string. Returns None when the field is missing or yields no names.
    """
    raw_names = _try_get_struct_field(st, "electrode_name")
    if raw_names is None:
        return None
    raw_names = _unwrap_singleton(raw_names)
    if isinstance(raw_names, (list, tuple)):
        cleaned = [str(item).strip() for item in raw_names]
    elif isinstance(raw_names, np.ndarray):
        cleaned = [str(_unwrap_singleton(item)).strip() for item in raw_names.reshape(-1)]
    else:
        single = str(raw_names).strip()
        cleaned = [single] if single else []
    return cleaned or None
def _extract_sample_rate(st):
    """Return the struct's ``sample_rate`` as a float, or None if it is missing or not numeric."""
    value = _try_get_struct_field(st, "sample_rate")
    if value is not None:
        value = _unwrap_singleton(value)
        try:
            return float(value)
        except Exception:
            pass
    return None
def _extract_xyz(st):
    """
    Read the struct's ``electrode_xyz`` field as an (N, 3) float array.

    A (3, N) array is transposed. Returns None when the field is missing, cannot
    be converted to floats, or is not 2-D with a dimension of size 3.
    """
    field = _try_get_struct_field(st, "electrode_xyz")
    if field is None:
        return None
    field = _unwrap_singleton(field)
    try:
        coords = np.asarray(field, dtype=float)
    except Exception:
        return None
    if coords.ndim != 2:
        return None
    if coords.shape[1] == 3:
        return coords
    if coords.shape[0] == 3:
        return coords.T
    return None
def load_eeg_from_mat(mat_path: str):
mat = scipy.io.loadmat(mat_path, struct_as_record=False, squeeze_me=True)
candidates = []
st_for_meta = None
for k, v in mat.items():
if k.startswith("__"):
continue
if isinstance(v, np.ndarray) and v.ndim == 2 and np.issubdtype(v.dtype, np.number):
candidates.append((k, v, None))
continue
data_field = _try_get_struct_field(v, "data")
if data_field is not None:
data_field = _unwrap_singleton(data_field)
if isinstance(data_field, np.ndarray) and data_field.ndim == 2:
candidates.append((f"{k}.data", data_field, v))
continue
if isinstance(v, np.ndarray) and v.dtype == object:
vv = _unwrap_singleton(v)
if isinstance(vv, np.ndarray) and vv.ndim == 2 and np.issubdtype(vv.dtype, np.number):
candidates.append((k, vv, None))
continue
data2 = _try_get_struct_field(vv, "data")
if data2 is not None:
data2 = _unwrap_singleton(data2)
if isinstance(data2, np.ndarray) and data2.ndim == 2:
candidates.append((f"{k}.data", data2, vv))
continue
if not candidates:
raise RuntimeError(f"mat 里没找到可用 EEG 二维矩阵或 struct.data{mat_path}")
def score(arr: np.ndarray) -> int:
s = 0
if 64 in arr.shape: s += 10
if 32 in arr.shape: s += 9
if 128 in arr.shape: s += 8
if 129 in arr.shape: s += 7
s += int(np.prod(arr.shape) // 100000)
return s
candidates.sort(key=lambda x: score(x[1]), reverse=True)
key, eeg, st = candidates[0]
st_for_meta = st
eeg = np.asarray(_unwrap_singleton(eeg), dtype=np.float32)
if eeg.ndim != 2:
raise RuntimeError(f"解析结果不是二维: key={key}, shape={eeg.shape}, file={mat_path}")
# 统一成 (T, C)
if eeg.shape[0] in (32, 64, 128, 129) and eeg.shape[1] not in (32, 64, 128, 129):
eeg = eeg.T
elif eeg.shape[1] in (32, 64, 128, 129):
if eeg.shape[0] in (32, 64, 128, 129) and eeg.shape[0] < eeg.shape[1]:
eeg = eeg.T
fs = DEFAULT_FS
ch_names = None
xyz = None
if st_for_meta is not None:
fs2 = _extract_sample_rate(st_for_meta)
if fs2 is not None and fs2 > 1:
fs = float(fs2)
ch_names = _extract_electrode_names(st_for_meta)
xyz = _extract_xyz(st_for_meta)
eeg_uV, msg, warn = _auto_scale_to_uV(eeg)
print(msg)
if warn:
print(warn)
return eeg_uV.astype(np.float32), float(fs), ch_names, xyz
# ==========================
# 预测接口:导入 predict_hc_mdd
# ==========================
def _predict_label_by_model(model_path: str, mat_dir: str) -> str:
try:
from infer_pth import predict_hc_mdd
except Exception as e:
raise RuntimeError(
"无法导入 predict_hc_mdd请确保 pre.py 或 infer_pth.py 与本文件同目录)。\n"
f"原始错误: {e}"
)
try:
out = predict_hc_mdd(mat_dir, model_path)
except TypeError:
out = predict_hc_mdd(model_path, mat_dir)
label = str(out.get("pred_label", "")).strip().upper()
if label not in ("HC", "MDD"):
raise RuntimeError(f"predict_hc_mdd 返回 pred_label 非法: {label},原始返回: {out}")
return label
# ==========================
# 通道分区
# ==========================
def _norm_name(s: str) -> str:
return str(s).strip().upper().replace(" ", "")
def build_channel_index_map(ch_names, n_channels: int):
    """
    Map normalized channel names (see _norm_name) to their column index.

    Returns {} when ``ch_names`` is missing or its length differs from
    ``n_channels``. With duplicate normalized names, the last index wins.
    """
    if ch_names and len(ch_names) == n_channels:
        mapping = {}
        for position, name in enumerate(ch_names):
            mapping[_norm_name(name)] = position
        return mapping
    return {}
def pick_indices_by_names(name_to_idx, names):
    """Return the sorted, de-duplicated indices of those ``names`` present in ``name_to_idx``."""
    found = {name_to_idx[key] for key in map(_norm_name, names) if key in name_to_idx}
    return sorted(found)
def _fallback_region_indices(n_channels: int):
a = int(n_channels * 0.33)
b = int(n_channels * 0.66)
frontal = list(range(0, a))
central = list(range(a, b))
parietal = list(range(b, n_channels))
prefrontal = list(range(0, max(2, a // 2)))
posterior = list(range(b, n_channels))
left = [i for i in range(n_channels) if i % 2 == 0]
right = [i for i in range(n_channels) if i % 2 == 1]
return frontal, central, parietal, prefrontal, posterior, left, right
def get_region_indices(name_to_idx, n_channels: int):
    """
    Map standard electrode-name lists onto channel indices, region by region.

    Returns a tuple (frontal, central, parietal, prefrontal, posterior, left, right)
    of sorted index lists. An empty ``name_to_idx`` yields the purely positional
    _fallback_region_indices split. If any of the first five regions matches no
    channel, every empty region (left/right included) is replaced by its
    positional fallback.
    """
    if not name_to_idx:
        return _fallback_region_indices(n_channels)
    region_order = ("frontal", "central", "parietal", "prefrontal", "posterior", "left", "right")
    region_names = {
        "central": ["CZ","C1","C2","C3","C4","C5","C6","CP1","CP2","CP3","CP4","CP5","CP6","FC1","FC2","FC3","FC4","FC5","FC6"],
        "frontal": ["FZ","F1","F2","F3","F4","F5","F6","F7","F8","AF3","AF4","AF7","AF8","FPZ","FP1","FP2","FCZ"],
        "parietal": ["PZ","P1","P2","P3","P4","P5","P6","POZ","PO3","PO4","PO5","PO6","PO7","PO8","CPZ"],
        "prefrontal": ["FP1","FP2","FPZ","AF3","AF4","AF7","AF8"],
        "posterior": ["O1","O2","OZ","PO7","PO8","PO3","PO4","PZ","P3","P4","P1","P2"],
        "left": ["FP1","AF3","AF7","F3","F5","F7"],
        "right": ["FP2","AF4","AF8","F4","F6","F8"],
    }
    picked = {
        region: pick_indices_by_names(name_to_idx, names)
        for region, names in region_names.items()
    }
    if not all(picked[region] for region in region_order[:5]):
        fallback = dict(zip(region_order, _fallback_region_indices(n_channels)))
        for region in region_order:
            if not picked[region]:
                picked[region] = fallback[region]
    return tuple(picked[region] for region in region_order)
# ==========================
# Welch PSD + band power
# ==========================
def welch_psd(eeg_tc: np.ndarray, fs: float):
    """
    Welch power spectral density along the time axis of a (T, C) array.

    Segments are 1024 samples long (the whole signal when it is shorter) with
    50 % overlap and density scaling. Returns (freqs, pxx), where pxx is (F, C).
    """
    seg_len = min(1024, eeg_tc.shape[0])
    return signal.welch(
        eeg_tc,
        fs=fs,
        nperseg=seg_len,
        noverlap=seg_len // 2,
        axis=0,
        scaling="density",
    )
def band_power_from_psd(freqs, pxx_fc, band):
    """
    Integrate an (F, C) PSD over the half-open band [lo, hi) with the trapezoid rule.

    Returns one float32 value per channel. All zeros if no frequency bin falls
    inside the band.
    """
    lo, hi = band
    in_band = (freqs >= lo) & (freqs < hi)
    if not in_band.any():
        return np.zeros((pxx_fc.shape[1],), dtype=np.float32)
    # numpy >= 2.0 provides trapezoid; older versions only have trapz.
    integrate = getattr(np, "trapezoid", None) or np.trapz
    power = integrate(pxx_fc[in_band, :], freqs[in_band], axis=0)
    return power.astype(np.float32)
def region_mean_power(freqs, pxx_fc, idx, band) -> float:
    """Average band power (band_power_from_psd) over the channels in ``idx``; 0.0 when ``idx`` is empty."""
    if idx:
        per_channel = band_power_from_psd(freqs, pxx_fc, band)
        return float(per_channel[idx].mean())
    return 0.0
def compute_iaf(freqs, pxx_fc, posterior_idx):
    """
    Individual alpha frequency of the posterior channels.

    The PSD is averaged over ``posterior_idx``. The frequency of the highest bin
    within the inclusive BANDS_METRICS["Alpha"] range is returned. Returns 0.0
    when no bin lies in that range or no posterior channel is given.
    """
    alpha_lo, alpha_hi = BANDS_METRICS["Alpha"]
    in_alpha = (freqs >= alpha_lo) & (freqs <= alpha_hi)
    if np.any(in_alpha) and posterior_idx:
        mean_spectrum = pxx_fc[:, posterior_idx].mean(axis=1)
        peak = int(np.argmax(mean_spectrum[in_alpha]))
        return float(freqs[in_alpha][peak])
    return 0.0
# ==========================
# 图EEG波形、PSD
# ==========================
def plot_eeg_waveforms(data_uv_tc: np.ndarray, fs: float, ch_names, out_dir: str, seconds: int = 10):
    """
    Save EEG.png with the first ``seconds`` of each channel listed in FIXED_EEG_IDXS.

    data_uv_tc: (T, C) array in μV.
    Panels follow FIXED_EEG_IDXS order (importance order). Each is titled with the
    matching FIXED_EEG_LABELS entry, or ``CH<idx>`` when no label exists.
    ``ch_names`` is accepted for interface compatibility but does not change the titles.
    All panels share a symmetric y-range taken from the 1st/99th percentiles of the
    plotted data, at least ±50 μV. Out-of-range indices are skipped with a warning.
    Raises RuntimeError when none of them exist.
    """
    n_samples, n_chans = data_uv_tc.shape
    # 1) Keep only indices that exist in this recording.
    valid = [i for i in FIXED_EEG_IDXS if 0 <= i < n_chans]
    invalid = [i for i in FIXED_EEG_IDXS if not (0 <= i < n_chans)]
    if invalid:
        print(f"[WARN] Some fixed EEG indices out of range (C={n_chans}): {invalid}")
    if not valid:
        raise RuntimeError(f"No valid indices in FIXED_EEG_IDXS for current data (C={n_chans}).")

    # 2) Panel titles come from the label at the same position in FIXED_EEG_LABELS.
    titles = []
    for ch in valid:
        pos = FIXED_EEG_IDXS.index(ch)
        titles.append(FIXED_EEG_LABELS[pos] if pos < len(FIXED_EEG_LABELS) else f"CH{ch}")

    # 3) Only the first `seconds` seconds are drawn.
    n_show = int(min(n_samples, seconds * fs))
    t = np.arange(n_show) / fs
    fig, axes = plt.subplots(len(valid), 1, figsize=(10, 1.4 * len(valid) + 1), sharex=True)
    if len(valid) == 1:
        axes = [axes]

    # 4) Percentile-based range so isolated spikes do not flatten the traces.
    window = data_uv_tc[:n_show, valid].T
    limit = max(
        abs(float(np.percentile(window, 1))),
        abs(float(np.percentile(window, 99))),
        50.0,
    )
    for ax, ch, title in zip(axes, valid, titles):
        ax.plot(t, data_uv_tc[:n_show, ch], linewidth=1.2)
        ax.set_ylabel("μV")
        ax.set_title(title, loc="left", fontsize=10)
        ax.grid(True, alpha=0.3)
        ax.set_ylim(-limit, limit)
    axes[-1].set_xlabel("Time (s)")
    plt.tight_layout()
    out_path = os.path.join(out_dir, "EEG.png")
    plt.savefig(out_path, dpi=200)
    plt.close(fig)
    print(f"[OK] EEG waveform saved: {out_path}")
def plot_psd(eeg_uV_tc, fs, ch_names, out_dir):
    """
    Save psd.png: the Welch PSD (dB) between PSD_FMIN and PSD_FMAX for three channels.

    Channel choice:
    - With usable names, C3, C4 and Cz are picked first (matched via _norm_name,
      so case and spaces do not matter). The list is then topped up with the
      highest-variance remaining channels until three are chosen.
    - Without usable names (missing, or a count that differs from the data's
      channel count), the three highest-variance channels are used, labelled
      ``CH<index>`` (0-based).

    Welch uses 2 s segments with 1 s overlap. Both are shortened for recordings
    under 2 s so that noverlap stays below nperseg.
    """
    n_samples, C = eeg_uV_tc.shape
    # Channel indices by standard deviation, largest first (stable for ties).
    by_std = sorted(range(C), key=lambda i: float(np.std(eeg_uV_tc[:, i])), reverse=True)
    if ch_names and len(ch_names) == C:
        name_to_idx = {_norm_name(n): i for i, n in enumerate(ch_names)}
        chosen_idx = [name_to_idx[p] for p in ("C3", "C4", "CZ") if p in name_to_idx]
        for i in by_std:
            if len(chosen_idx) >= 3:
                break
            if i not in chosen_idx:
                chosen_idx.append(i)
        chosen_name = [str(ch_names[i]) for i in chosen_idx]
    else:
        chosen_idx = by_std[:3]
        chosen_name = [f"CH{i}" for i in chosen_idx]

    nperseg = max(1, min(int(2 * fs), n_samples))
    noverlap = min(int(1 * fs), nperseg // 2)
    fig = plt.figure(figsize=(7.5, 4.8))
    for idx, nm in zip(chosen_idx, chosen_name):
        f, pxx = signal.welch(eeg_uV_tc[:, idx], fs=fs, nperseg=nperseg, noverlap=noverlap)
        mask = (f >= PSD_FMIN) & (f <= PSD_FMAX)
        p_db = 10 * np.log10(pxx[mask] + 1e-20)
        plt.plot(f[mask], p_db, linewidth=1.8, label=nm)
    plt.xlabel("Hz")
    plt.ylabel("Power (dB)")
    plt.title("PSD")
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.tight_layout()
    out_path = os.path.join(out_dir, "psd.png")
    plt.savefig(out_path, dpi=200)
    plt.close(fig)
    print(f"[OK] psd.png -> {out_path}")
# ==========================
# Topomap如果有 xyz
# ==========================
def build_mne_raw_from_uV(eeg_uV_tc, fs, ch_names, xyz):
    """
    Wrap a (T, C) μV array in an MNE RawArray (converted to volts) and attach positions.

    Generic names CH1..CHn are used when ``ch_names`` is missing or its length does
    not match C, because mne.create_info needs exactly one name per channel. A
    head-frame montage is set only when ``xyz`` is an ndarray of shape (C, 3).
    Otherwise a warning is printed and the returned Raw has no positions.
    """
    C = eeg_uV_tc.shape[1]
    if not ch_names or len(ch_names) != C:
        ch_names = [f"CH{i+1}" for i in range(C)]
    data_v_ct = eeg_uV_tc.T * 1e-6  # (C, T) in V
    info = mne.create_info(ch_names=ch_names, sfreq=fs, ch_types=["eeg"] * C)
    raw = mne.io.RawArray(data_v_ct, info, verbose=False)
    if isinstance(xyz, np.ndarray) and xyz.shape == (C, 3):
        try:
            ch_pos = {ch_names[i]: xyz[i, :] for i in range(C)}
            montage = mne.channels.make_dig_montage(ch_pos=ch_pos, coord_frame="head")
            raw.set_montage(montage, on_missing="ignore")
        except Exception as e:
            print(f"[WARN] set_montage failed (ignore): {e}")
    else:
        print("[WARN] electrode_xyz missing/invalid -> skip topomap")
    return raw
def _raw_has_positions(raw):
try:
locs = np.array([ch["loc"][:3] for ch in raw.info["chs"]])
ok = np.isfinite(locs).all() and (np.linalg.norm(locs, axis=1) > 0).any()
return bool(ok)
except Exception:
return False
def compute_band_powers_for_topomap(raw, bands):
    """
    Per-channel, mean-centred log10 band power for every band in ``bands``.

    The PSD comes from mne.time_frequency.psd_array_welch on ``raw``'s data, with
    2 s FFT windows and 1 s overlap, limited to the span covering all bands. Each
    band is integrated with the trapezoid rule over an inclusive [fmin, fmax]
    mask and converted to log10. The channel mean is then subtracted, so the
    topomap shows relative power.

    Returns {band_name: array with one value per channel}.
    """
    sfreq = raw.info["sfreq"]
    lowest = min(v[0] for v in bands.values())
    highest = max(v[1] for v in bands.values())
    psds, freqs = mne.time_frequency.psd_array_welch(
        raw.get_data(),
        sfreq=sfreq,
        fmin=lowest,
        fmax=highest,
        n_fft=int(2 * sfreq),
        n_overlap=int(1 * sfreq),
        average="mean",
        verbose=False,
    )
    # numpy >= 2.0 provides trapezoid; older versions only have trapz.
    integrate = getattr(np, "trapezoid", None) or np.trapz
    result = {}
    for name, (fmin, fmax) in bands.items():
        sel = np.flatnonzero((freqs >= fmin) & (freqs <= fmax))
        power = integrate(psds[:, sel], freqs[sel], axis=1)
        log_power = np.log10(power + 1e-30)
        result[name] = log_power - log_power.mean()
    return result
def plot_average_topomap(raw, values, out_dir):
    """
    Save average_topomap.png: one topomap of ``values`` (one per channel of ``raw``).

    The map is titled "0.8-30 Hz" and has a colour bar. The sphere radius 0.11 is
    presumably the head radius in metres; confirm against the montage units.
    """
    out_path = os.path.join(out_dir, "average_topomap.png")
    fig, ax = plt.subplots(1, 1, figsize=(6.5, 4.6))
    image, _ = mne.viz.plot_topomap(
        values, raw.info, axes=ax, show=False, contours=0, sphere=(0, 0, 0, 0.11)
    )
    ax.set_title("0.8-30 Hz", fontsize=12)
    plt.colorbar(image, ax=ax, shrink=0.85)
    plt.tight_layout()
    plt.savefig(out_path, dpi=200)
    plt.close(fig)
    print(f"[OK] average_topomap.png -> {out_path}")
def plot_band_topomaps(raw, band_values, out_dir):
    """
    Save topomaps.png: delta, theta, alpha, beta and broad-band topomaps side by side.

    ``band_values`` maps each key ("delta", "theta", "alpha", "beta", "broad") to
    one value per channel of ``raw``. All panels get the same symmetric colour
    limits (± the largest finite |value| across the five maps). That way the single
    colour bar drawn on the right is valid for every panel; independently
    auto-scaled maps would not match it.
    """
    order = [
        ("delta", "δ (0.8-3.9Hz)"),
        ("theta", "θ (4-7.9Hz)"),
        ("alpha", "α (8-12.9Hz)"),
        ("beta", "β (13-30Hz)"),
        ("broad", "0.8-30 Hz"),
    ]
    # Shared symmetric colour range across all panels.
    all_vals = np.concatenate([np.ravel(np.asarray(band_values[k], dtype=float)) for k, _ in order])
    finite = all_vals[np.isfinite(all_vals)]
    vmax = float(np.max(np.abs(finite))) if finite.size else 0.0

    fig, axes = plt.subplots(1, 5, figsize=(16, 4.2))
    ims = []
    for ax, (k, title) in zip(axes, order):
        im, _ = mne.viz.plot_topomap(band_values[k], raw.info, axes=ax, show=False, contours=0, extrapolate='head', sphere=(0, 0, 0, 0.11))
        if vmax > 0:
            im.set_clim(-vmax, vmax)
        ax.set_title(title, fontsize=11)
        ims.append(im)
    fig.subplots_adjust(left=0.02, right=0.85, top=0.88, bottom=0.05, wspace=0.35)
    cax = fig.add_axes([0.87, 0.15, 0.015, 0.7])
    fig.colorbar(ims[-1], cax=cax)
    out_path = os.path.join(out_dir, "topomaps.png")
    plt.savefig(out_path, dpi=200)
    plt.close(fig)
    print(f"[OK] topomaps.png -> {out_path}")
# ==========================
# 生成 ResultData.txt
# ==========================
def compute_and_save_txt(model_path, mat_dir, out_dir, eeg_uV_tc, fs, ch_names):
    """
    Predict HC/MDD with the model and write ResultData.txt with EEG spectral metrics.

    Metrics use a Welch PSD (welch_psd) of the (T, C) μV signal, averaged over the
    channel regions from get_region_indices:
    - α/β ratios (central, frontal, parietal) and θ/β ratios (central, parietal);
    - frontal alpha asymmetry ln(right α) - ln(left α);
    - individual alpha peak frequency over the posterior channels;
    - prefrontal δ+θ power as a percentage of TOTAL_POWER_BAND power.
    A ratio is reported as 0.0 when its denominator power is not positive.
    The last line states whether treatment is recommended: "是" for an MDD
    prediction, "否" for HC. Values are written with one decimal place.
    """
    pred_label = _predict_label_by_model(model_path, mat_dir)
    # MDD -> treatment recommended ("是", yes); HC -> not recommended ("否", no).
    recommend = "是" if pred_label == "MDD" else "否"

    n_channels = eeg_uV_tc.shape[1]
    name_to_idx = build_channel_index_map(ch_names, n_channels)
    (frontal_idx, central_idx, parietal_idx, prefrontal_idx,
     posterior_idx, left_idx, right_idx) = get_region_indices(name_to_idx, n_channels)

    freqs, pxx = welch_psd(eeg_uV_tc, fs)

    def band_power(idx, band_name):
        """Mean power of BANDS_METRICS[band_name] over channels ``idx``."""
        return region_mean_power(freqs, pxx, idx, BANDS_METRICS[band_name])

    def safe_ratio(num, den):
        """num / den, or 0.0 when the denominator power is not positive."""
        return (num / (den + EPS)) if den > 0 else 0.0

    central_alpha = band_power(central_idx, "Alpha")
    central_beta = band_power(central_idx, "Beta")
    central_theta = band_power(central_idx, "Theta")
    frontal_alpha = band_power(frontal_idx, "Alpha")
    frontal_beta = band_power(frontal_idx, "Beta")
    par_alpha = band_power(parietal_idx, "Alpha")
    par_beta = band_power(parietal_idx, "Beta")
    par_theta = band_power(parietal_idx, "Theta")

    central_ab = safe_ratio(central_alpha, central_beta)
    frontal_ab = safe_ratio(frontal_alpha, frontal_beta)
    par_ab = safe_ratio(par_alpha, par_beta)
    central_tb = safe_ratio(central_theta, central_beta)
    par_tb = safe_ratio(par_theta, par_beta)

    # Without named left/right frontal channels, split the prefrontal set by index parity.
    if not left_idx or not right_idx:
        left_idx = [i for i in prefrontal_idx if i % 2 == 0]
        right_idx = [i for i in prefrontal_idx if i % 2 == 1]
    left_alpha = band_power(left_idx, "Alpha")
    right_alpha = band_power(right_idx, "Alpha")
    prefrontal_alpha_asym = float(np.log(right_alpha + EPS) - np.log(left_alpha + EPS))

    iaf = compute_iaf(freqs, pxx, posterior_idx)

    # Delta start .. theta end, relative to TOTAL_POWER_BAND, in percent.
    td_band = (BANDS_METRICS["Delta"][0], BANDS_METRICS["Theta"][1])
    pre_td = region_mean_power(freqs, pxx, prefrontal_idx, td_band)
    pre_total = region_mean_power(freqs, pxx, prefrontal_idx, TOTAL_POWER_BAND)
    pre_td_rel = safe_ratio(pre_td, pre_total) * 100.0

    def f1(x):
        return f"{x:.1f}"

    txt = (
        f"中央区α/β波比值:{f1(central_ab)}\n"
        f"额区α/β波比值:{f1(frontal_ab)}\n"
        f"顶区α/β波比值:{f1(par_ab)}\n"
        f"中央区θ/β波比值:{f1(central_tb)}\n"
        f"顶区θ/β波比值:{f1(par_tb)}\n"
        f"前额叶α波不对称性:{f1(prefrontal_alpha_asym)}\n"
        f"个体化α峰值频率:{f1(iaf)}\n"
        f"前额叶θ+δ波功率:{f1(pre_td_rel)}\n"
        f"是否推荐治疗:{recommend}\n"
    )
    out_path = os.path.join(out_dir, "ResultData.txt")
    with open(out_path, "w", encoding="utf-8") as f:
        f.write(txt)
    print(f"[OK] ResultData.txt -> {out_path}")
# ==========================
# 一个函数一次性跑完txt + 图片)
# ==========================
def run_all(model_path: str, mat_dir: str, out_root: str, seconds: int = EEG_PLOT_SECONDS):
    """
    Run the whole pipeline on the first .mat file (by sorted name) in ``mat_dir``.

    Order: preprocess -> psd.png / EEG.png -> topomaps (only when electrode
    positions exist; failures are reported and skipped) -> ResultData.txt.

    ``out_root`` is cleaned by ensure_outdir (keeping ResultData.txt) and used
    directly. The preprocessed copy is written to ``out_root/temp_preprocessed``,
    and that folder is what the model is pointed at.

    Returns the output directory. Raises RuntimeError when ``mat_dir`` is not an
    existing directory or holds no .mat file.
    """
    # 1) Pick the first .mat
    if not os.path.isdir(mat_dir):
        raise RuntimeError(f"输入目录不存在: {mat_dir}")
    mats = sorted(f for f in os.listdir(mat_dir) if f.lower().endswith(".mat"))
    if not mats:
        raise RuntimeError(f"mat_dir 下找不到 .mat: {mat_dir}")
    mat_file = os.path.join(mat_dir, mats[0])
    print(f"[INFO] Found mat: {mat_file}")
    # 2) Prepare the output directory
    out_dir = ensure_outdir(out_root)
    print(f"[INFO] Output dir: {out_dir}")
    # 3) Always preprocess (default mode)
    print("[INFO] Mode: Raw Data (Default). Running preprocessing...")
    temp_dir = os.path.join(out_dir, "temp_preprocessed")
    mat_file = preprocess_mat_file(mat_file, temp_dir)
    # The model reads from the folder holding the preprocessed file.
    mat_dir = temp_dir
    # 4) Load EEG (μV)
    eeg_uV_tc, fs, ch_names, xyz = load_eeg_from_mat(mat_file)
    print(f"[INFO] eeg shape(T,C)={eeg_uV_tc.shape}, fs={fs}")
    # 5) Figures: PSD + EEG waveforms
    plot_psd(eeg_uV_tc, fs, ch_names, out_dir)
    plot_eeg_waveforms(eeg_uV_tc, fs, ch_names, out_dir, seconds=seconds)
    # 6) Topomaps (only with electrode positions); best-effort
    try:
        raw = build_mne_raw_from_uV(eeg_uV_tc, fs, ch_names, xyz)
        if _raw_has_positions(raw):
            band_vals = compute_band_powers_for_topomap(raw, BANDS_TOPOMAP)
            plot_average_topomap(raw, band_vals["broad"], out_dir)
            plot_band_topomaps(raw, band_vals, out_dir)
        else:
            print("[WARN] No valid positions -> skip topomap.")
    except Exception as e:
        print(f"[WARN] topomap failed -> skip. reason: {e}")
    # 7) Metrics + prediction -> ResultData.txt
    compute_and_save_txt(model_path, mat_dir, out_dir, eeg_uV_tc, fs, ch_names)
    print("[DONE] txt + figures generated.")
    return out_dir
if __name__ == "__main__":
    import multiprocessing
    # Needed when this script is frozen into an executable that spawns processes.
    multiprocessing.freeze_support()
    import argparse
    import sys

    def get_base_dir():
        """
        Directory holding the packaged EXE (PyInstaller, sys.frozen) or, in
        development, this script. Resources (model/, raw_data/, out/) are
        expected next to it ("portable software" layout).
        """
        if getattr(sys, 'frozen', False):
            return os.path.dirname(sys.executable)
        return os.path.dirname(os.path.abspath(__file__))

    def get_resource_path(relative_path):
        """Absolute path of ``relative_path`` under get_base_dir()."""
        return os.path.join(get_base_dir(), relative_path)

    # 1. Default paths, all relative to the EXE / script directory
    EXE_DIR = get_base_dir()
    DEFAULT_MODEL = get_resource_path(os.path.join("model", "Model_1.pth"))
    DEFAULT_MAT = os.path.join(EXE_DIR, "raw_data")
    DEFAULT_OUT = os.path.join(EXE_DIR, "out")

    # 2. Command-line arguments
    parser = argparse.ArgumentParser(description="EEG Depression Assessment Algorithm Integration")
    parser.add_argument("--model_path", type=str, default=DEFAULT_MODEL, help="模型文件的路径 (.pth)")
    parser.add_argument("--mat_dir", type=str, default=DEFAULT_MAT, help="输入文件夹路径 (包含原始EEG .mat)")
    parser.add_argument("--out_root", type=str, default=DEFAULT_OUT, help="结果输出的根目录")
    parser.add_argument("--seconds", type=int, default=EEG_PLOT_SECONDS, help="画波形图截取的秒数")
    args = parser.parse_args()

    # 3. Warn early about missing inputs (run_all raises for a missing mat_dir)
    if not os.path.exists(args.mat_dir):
        print(f"[WARN] 输入文件夹不存在: {args.mat_dir}")
    if not os.path.exists(args.model_path):
        print(f"[WARN] 模型文件不存在: {args.model_path}")

    # 4. Run the pipeline
    print("[*] 运行配置:")
    print(f" - Model : {args.model_path}")
    print(f" - Input : {args.mat_dir}")
    print(f" - Output: {args.out_root}")
    print(" - Mode : RAW (Auto Preprocess)")
    run_all(args.model_path, args.mat_dir, args.out_root, seconds=args.seconds)