original push
This commit is contained in:
909
algorithm_V0/algorithm_fromXjtu/runDecoder.py
Normal file
909
algorithm_V0/algorithm_fromXjtu/runDecoder.py
Normal file
@@ -0,0 +1,909 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
"""
|
||||
run_metrics_and_figs.py
|
||||
|
||||
1) 自动读取 mat_dir 中排序后的第一个 .mat
|
||||
2) 调用模型预测(HC/MDD)并写 ResultData.txt
|
||||
3) 同时保存图片:EEG.png / psd.png / average_topomap.png / topomaps.png
|
||||
|
||||
"""
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
import numpy as np
|
||||
import os
|
||||
import shutil
|
||||
import scipy.io
|
||||
import scipy.signal as signal
|
||||
import matplotlib.pyplot as plt
|
||||
import mne
|
||||
from mne.preprocessing import ICA
|
||||
|
||||
# ==========================
|
||||
# Config
|
||||
# ==========================
|
||||
PREPROCESS_BANDPASS = (0.8, 30.0)
|
||||
PREPROCESS_NOTCH = [50, 100]
|
||||
PREPROCESS_ICA_N = 0.99
|
||||
PREPROCESS_ICA_SEED = 97
|
||||
PREPROCESS_APPLY_AVG_REF = True
|
||||
PREPROCESS_BAD_PTP_UV = 350.0 # 坏段阈值 (μV)
|
||||
|
||||
DEFAULT_FS = 250.0
|
||||
EEG_PLOT_SECONDS = 10
|
||||
PSD_FMIN, PSD_FMAX = 0.8, 45.0
|
||||
EPS = 1e-12
|
||||
FIXED_EEG_IDXS = [23, 47, 39, 6, 2, 21, 35, 57] # 0-based index, 按重要性排序
|
||||
FIXED_EEG_LABELS = ["C5", "O1", "TP7", "FPZ", "PO6", "P4", "AF7", "AF3"]
|
||||
|
||||
BANDS_METRICS = {
|
||||
"Delta": (1.0, 4.0),
|
||||
"Theta": (4.0, 8.0),
|
||||
"Alpha": (8.0, 13.0),
|
||||
"Beta": (13.0, 30.0),
|
||||
}
|
||||
TOTAL_POWER_BAND = (1.0, 50.0)
|
||||
|
||||
BANDS_TOPOMAP = {
|
||||
"delta": (0.8, 3.9),
|
||||
"theta": (4.0, 7.9),
|
||||
"alpha": (8.0, 12.9),
|
||||
"beta": (13.0, 30.0),
|
||||
"broad": (0.8, 30.0),
|
||||
}
|
||||
|
||||
|
||||
# ==========================
|
||||
# 预处理逻辑
|
||||
# ==========================
|
||||
def annotate_bad_segments(raw, peak_to_peak_uv=250.0):
|
||||
"""
|
||||
简单坏段检测:按固定窗口计算峰峰值,超过阈值标为 bad。
|
||||
"""
|
||||
peak_to_peak_v = peak_to_peak_uv * 1e-6
|
||||
win = int(raw.info["sfreq"] * 1.0)
|
||||
step = int(raw.info["sfreq"] * 0.5)
|
||||
data = raw.get_data()
|
||||
n_times = data.shape[1]
|
||||
onsets = []
|
||||
durations = []
|
||||
descriptions = []
|
||||
|
||||
for start in range(0, n_times - win, step):
|
||||
seg = data[:, start:start + win]
|
||||
ptp = np.ptp(seg, axis=1)
|
||||
if np.any(ptp > peak_to_peak_v):
|
||||
onsets.append(start / raw.info["sfreq"])
|
||||
durations.append(win / raw.info["sfreq"])
|
||||
descriptions.append("BAD_PTP")
|
||||
|
||||
if len(onsets) > 0:
|
||||
ann = mne.Annotations(onset=onsets, duration=durations, description=descriptions)
|
||||
raw.set_annotations(ann)
|
||||
print(f"[INFO] Annotated bad segments: {len(onsets)} windows")
|
||||
else:
|
||||
print("[INFO] No bad segments detected by PTP rule")
|
||||
|
||||
|
||||
def run_preprocess_on_raw(raw: mne.io.RawArray) -> mne.io.RawArray:
|
||||
"""
|
||||
核心预处理:滤波 + 平均参考 + 坏段标注 + ICA
|
||||
"""
|
||||
# 1) 滤波
|
||||
raw.filter(PREPROCESS_BANDPASS[0], PREPROCESS_BANDPASS[1], fir_design="firwin", verbose=False)
|
||||
raw.notch_filter(PREPROCESS_NOTCH, fir_design="firwin", verbose=False)
|
||||
|
||||
# 2) 平均参考
|
||||
if PREPROCESS_APPLY_AVG_REF:
|
||||
raw.set_eeg_reference("average", verbose=False)
|
||||
|
||||
# 3) 坏段标注
|
||||
annotate_bad_segments(raw, peak_to_peak_uv=PREPROCESS_BAD_PTP_UV)
|
||||
|
||||
# 4) ICA
|
||||
ica = ICA(
|
||||
n_components=PREPROCESS_ICA_N,
|
||||
random_state=PREPROCESS_ICA_SEED,
|
||||
max_iter=800,
|
||||
method="fastica"
|
||||
)
|
||||
ica.fit(raw, reject_by_annotation=True, verbose=False)
|
||||
|
||||
try:
|
||||
eog_inds, _ = ica.find_bads_eog(raw, verbose=False)
|
||||
if eog_inds:
|
||||
ica.exclude.extend(eog_inds)
|
||||
print(f"[INFO] ICA exclude EOG comps: {eog_inds}")
|
||||
except Exception as e:
|
||||
print(f"[WARN] ICA find_bads_eog skipped: {e}")
|
||||
|
||||
raw_clean = ica.apply(raw.copy(), verbose=False)
|
||||
return raw_clean
|
||||
|
||||
def preprocess_mat_file(src_mat_path: str, temp_out_dir: str) -> str:
|
||||
"""
|
||||
读取原始mat -> 预处理 -> 保存到 temp_out_dir -> 返回新路径
|
||||
"""
|
||||
os.makedirs(temp_out_dir, exist_ok=True)
|
||||
|
||||
# 1. 读原始 mat
|
||||
# 注意:这里我们只要数据部分转成 MNE Raw,然后处理,再存回
|
||||
# 复用现有的 load_eeg_from_mat 拿到 ndarray
|
||||
eeg_uV, fs, ch_names, xyz = load_eeg_from_mat(src_mat_path)
|
||||
|
||||
# 转 MNE (注意单位:uV -> V)
|
||||
if not ch_names:
|
||||
ch_names = [f"CH{i+1}" for i in range(eeg_uV.shape[1])]
|
||||
|
||||
info = mne.create_info(ch_names=ch_names, sfreq=fs, ch_types=["eeg"] * len(ch_names))
|
||||
raw = mne.io.RawArray(eeg_uV.T * 1e-6, info, verbose=False)
|
||||
|
||||
if xyz is not None and isinstance(xyz, np.ndarray):
|
||||
# 尝试设 montage(虽然对滤波不关键,但尽量保留信息)
|
||||
try:
|
||||
ch_pos = {ch_names[i]: xyz[i, :] for i in range(len(ch_names))}
|
||||
montage = mne.channels.make_dig_montage(ch_pos=ch_pos, coord_frame="head")
|
||||
raw.set_montage(montage, on_missing="ignore")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 2. 执行预处理
|
||||
print(f"[INFO] Start preprocessing: {src_mat_path}")
|
||||
raw_clean = run_preprocess_on_raw(raw)
|
||||
|
||||
# 3. 存回 .mat (保持结构兼容,以便后续 run_all 读取)
|
||||
# 这里我们需要读取原始 mat 的结构体,把 data 替换掉
|
||||
try:
|
||||
mat_struct = scipy.io.loadmat(src_mat_path, struct_as_record=False, squeeze_me=True)
|
||||
if "eeg" in mat_struct:
|
||||
eeg_obj = mat_struct["eeg"]
|
||||
# 替换数据:MNE (V) -> uV -> (T, C)
|
||||
clean_data_uV = (raw_clean.get_data() * 1e6).T
|
||||
eeg_obj.data = clean_data_uV
|
||||
|
||||
base_name = os.path.basename(src_mat_path)
|
||||
new_path = os.path.join(temp_out_dir, base_name)
|
||||
scipy.io.savemat(new_path, {"eeg": eeg_obj}, do_compression=True)
|
||||
print(f"[INFO] Preprocessed file saved to: {new_path}")
|
||||
return new_path
|
||||
except Exception as e:
|
||||
print(f"[WARN] Failed to preserve original struct structure: {e}")
|
||||
|
||||
# Fallback: 如果读原始结构失败,就存一个简单的 mat
|
||||
clean_data_uV = (raw_clean.get_data() * 1e6).T
|
||||
out_dict = {
|
||||
"eeg": {
|
||||
"data": clean_data_uV,
|
||||
"sample_rate": fs,
|
||||
"electrode_name": ch_names,
|
||||
"electrode_xyz": xyz if xyz is not None else []
|
||||
}
|
||||
}
|
||||
base_name = os.path.basename(src_mat_path)
|
||||
new_path = os.path.join(temp_out_dir, base_name)
|
||||
scipy.io.savemat(new_path, out_dict, do_compression=True)
|
||||
print(f"[INFO] Preprocessed file saved (fallback mode) to: {new_path}")
|
||||
return new_path
|
||||
|
||||
|
||||
# ==========================
|
||||
# 输出目录
|
||||
# ==========================
|
||||
def ensure_outdir(out_root: str) -> str:
|
||||
"""
|
||||
确保输出目录存在,并清空除 ResultData.txt 之外的旧文件。
|
||||
不再创建 timestamp 子文件夹,直接输出到 out_root。
|
||||
"""
|
||||
if os.path.exists(out_root):
|
||||
# 清空目录,但保留 ResultData.txt
|
||||
for filename in os.listdir(out_root):
|
||||
if filename == "ResultData.txt":
|
||||
continue
|
||||
file_path = os.path.join(out_root, filename)
|
||||
try:
|
||||
if os.path.isfile(file_path) or os.path.islink(file_path):
|
||||
os.unlink(file_path)
|
||||
elif os.path.isdir(file_path):
|
||||
shutil.rmtree(file_path)
|
||||
except Exception as e:
|
||||
print(f"[WARN] Failed to delete {file_path}. Reason: {e}")
|
||||
else:
|
||||
os.makedirs(out_root, exist_ok=True)
|
||||
|
||||
return out_root
|
||||
|
||||
|
||||
# ==========================
|
||||
# 单位自动识别:统一到 μV
|
||||
# ==========================
|
||||
def _auto_scale_to_uV(data_nt_nc: np.ndarray):
|
||||
data = np.asarray(data_nt_nc)
|
||||
p95 = float(np.percentile(np.abs(data), 95))
|
||||
|
||||
if p95 <= 0.5:
|
||||
data_uV = data * 1e6
|
||||
msg = f"[UNIT] p95={p95:.3g} -> assume V, convert to μV by *1e6"
|
||||
elif p95 > 5000:
|
||||
data_uV = data * 1e-3
|
||||
msg = f"[UNIT] p95={p95:.3g} -> assume nV, convert to μV by /1000"
|
||||
else:
|
||||
data_uV = data
|
||||
msg = f"[UNIT] p95={p95:.3g} -> assume μV, no scaling"
|
||||
|
||||
p95_uV = float(np.percentile(np.abs(data_uV), 95))
|
||||
warn = None
|
||||
if p95_uV > 5000:
|
||||
warn = f"[WARN] After scaling, p95 still large: {p95_uV:.3g} μV"
|
||||
elif p95_uV < 0.1:
|
||||
warn = f"[WARN] After scaling, p95 still small: {p95_uV:.3g} μV"
|
||||
|
||||
return data_uV, msg, warn
|
||||
|
||||
|
||||
# ==========================
|
||||
# mat 读取(支持 struct.data / electrode_name / electrode_xyz / sample_rate)
|
||||
# ==========================
|
||||
def _unwrap_singleton(x):
|
||||
while True:
|
||||
if isinstance(x, np.ndarray):
|
||||
if x.dtype == object and x.size == 1:
|
||||
x = x.item()
|
||||
continue
|
||||
if x.size == 1 and x.ndim >= 1:
|
||||
try:
|
||||
x = x.reshape(-1)[0]
|
||||
continue
|
||||
except Exception:
|
||||
pass
|
||||
break
|
||||
return x
|
||||
|
||||
|
||||
def _try_get_struct_field(v, field_name="data"):
|
||||
if hasattr(v, "_fieldnames") and field_name in getattr(v, "_fieldnames", []):
|
||||
return getattr(v, field_name)
|
||||
if isinstance(v, np.ndarray) and v.dtype.names and field_name in v.dtype.names:
|
||||
try:
|
||||
return v[field_name]
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def _extract_electrode_names(st):
|
||||
nf = _try_get_struct_field(st, "electrode_name")
|
||||
if nf is None:
|
||||
return None
|
||||
nf = _unwrap_singleton(nf)
|
||||
if isinstance(nf, (list, tuple)):
|
||||
names = [str(x).strip() for x in nf]
|
||||
return names if names else None
|
||||
if isinstance(nf, np.ndarray):
|
||||
flat = nf.reshape(-1)
|
||||
names = [str(_unwrap_singleton(x)).strip() for x in flat]
|
||||
return names if names else None
|
||||
s = str(nf).strip()
|
||||
return [s] if s else None
|
||||
|
||||
|
||||
def _extract_sample_rate(st):
|
||||
sr = _try_get_struct_field(st, "sample_rate")
|
||||
if sr is None:
|
||||
return None
|
||||
sr = _unwrap_singleton(sr)
|
||||
try:
|
||||
return float(sr)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _extract_xyz(st):
|
||||
xyz = _try_get_struct_field(st, "electrode_xyz")
|
||||
if xyz is None:
|
||||
return None
|
||||
xyz = _unwrap_singleton(xyz)
|
||||
try:
|
||||
xyz = np.asarray(xyz, dtype=float)
|
||||
if xyz.ndim == 2 and xyz.shape[1] == 3:
|
||||
return xyz
|
||||
if xyz.ndim == 2 and xyz.shape[0] == 3:
|
||||
return xyz.T
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def load_eeg_from_mat(mat_path: str):
|
||||
mat = scipy.io.loadmat(mat_path, struct_as_record=False, squeeze_me=True)
|
||||
|
||||
candidates = []
|
||||
st_for_meta = None
|
||||
|
||||
for k, v in mat.items():
|
||||
if k.startswith("__"):
|
||||
continue
|
||||
|
||||
if isinstance(v, np.ndarray) and v.ndim == 2 and np.issubdtype(v.dtype, np.number):
|
||||
candidates.append((k, v, None))
|
||||
continue
|
||||
|
||||
data_field = _try_get_struct_field(v, "data")
|
||||
if data_field is not None:
|
||||
data_field = _unwrap_singleton(data_field)
|
||||
if isinstance(data_field, np.ndarray) and data_field.ndim == 2:
|
||||
candidates.append((f"{k}.data", data_field, v))
|
||||
continue
|
||||
|
||||
if isinstance(v, np.ndarray) and v.dtype == object:
|
||||
vv = _unwrap_singleton(v)
|
||||
if isinstance(vv, np.ndarray) and vv.ndim == 2 and np.issubdtype(vv.dtype, np.number):
|
||||
candidates.append((k, vv, None))
|
||||
continue
|
||||
data2 = _try_get_struct_field(vv, "data")
|
||||
if data2 is not None:
|
||||
data2 = _unwrap_singleton(data2)
|
||||
if isinstance(data2, np.ndarray) and data2.ndim == 2:
|
||||
candidates.append((f"{k}.data", data2, vv))
|
||||
continue
|
||||
|
||||
if not candidates:
|
||||
raise RuntimeError(f"mat 里没找到可用 EEG 二维矩阵或 struct.data:{mat_path}")
|
||||
|
||||
def score(arr: np.ndarray) -> int:
|
||||
s = 0
|
||||
if 64 in arr.shape: s += 10
|
||||
if 32 in arr.shape: s += 9
|
||||
if 128 in arr.shape: s += 8
|
||||
if 129 in arr.shape: s += 7
|
||||
s += int(np.prod(arr.shape) // 100000)
|
||||
return s
|
||||
|
||||
candidates.sort(key=lambda x: score(x[1]), reverse=True)
|
||||
key, eeg, st = candidates[0]
|
||||
st_for_meta = st
|
||||
|
||||
eeg = np.asarray(_unwrap_singleton(eeg), dtype=np.float32)
|
||||
if eeg.ndim != 2:
|
||||
raise RuntimeError(f"解析结果不是二维: key={key}, shape={eeg.shape}, file={mat_path}")
|
||||
|
||||
# 统一成 (T, C)
|
||||
if eeg.shape[0] in (32, 64, 128, 129) and eeg.shape[1] not in (32, 64, 128, 129):
|
||||
eeg = eeg.T
|
||||
elif eeg.shape[1] in (32, 64, 128, 129):
|
||||
if eeg.shape[0] in (32, 64, 128, 129) and eeg.shape[0] < eeg.shape[1]:
|
||||
eeg = eeg.T
|
||||
|
||||
fs = DEFAULT_FS
|
||||
ch_names = None
|
||||
xyz = None
|
||||
if st_for_meta is not None:
|
||||
fs2 = _extract_sample_rate(st_for_meta)
|
||||
if fs2 is not None and fs2 > 1:
|
||||
fs = float(fs2)
|
||||
ch_names = _extract_electrode_names(st_for_meta)
|
||||
xyz = _extract_xyz(st_for_meta)
|
||||
|
||||
eeg_uV, msg, warn = _auto_scale_to_uV(eeg)
|
||||
print(msg)
|
||||
if warn:
|
||||
print(warn)
|
||||
|
||||
return eeg_uV.astype(np.float32), float(fs), ch_names, xyz
|
||||
|
||||
|
||||
# ==========================
|
||||
# 预测接口:导入 predict_hc_mdd
|
||||
# ==========================
|
||||
def _predict_label_by_model(model_path: str, mat_dir: str) -> str:
|
||||
try:
|
||||
from infer_pth import predict_hc_mdd
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
"无法导入 predict_hc_mdd(请确保 pre.py 或 infer_pth.py 与本文件同目录)。\n"
|
||||
f"原始错误: {e}"
|
||||
)
|
||||
|
||||
try:
|
||||
out = predict_hc_mdd(mat_dir, model_path)
|
||||
except TypeError:
|
||||
out = predict_hc_mdd(model_path, mat_dir)
|
||||
|
||||
label = str(out.get("pred_label", "")).strip().upper()
|
||||
if label not in ("HC", "MDD"):
|
||||
raise RuntimeError(f"predict_hc_mdd 返回 pred_label 非法: {label},原始返回: {out}")
|
||||
return label
|
||||
|
||||
|
||||
# ==========================
|
||||
# 通道分区
|
||||
# ==========================
|
||||
def _norm_name(s: str) -> str:
|
||||
return str(s).strip().upper().replace(" ", "")
|
||||
|
||||
|
||||
def build_channel_index_map(ch_names, n_channels: int):
|
||||
if not ch_names or len(ch_names) != n_channels:
|
||||
return {}
|
||||
return {_norm_name(nm): i for i, nm in enumerate(ch_names)}
|
||||
|
||||
|
||||
def pick_indices_by_names(name_to_idx, names):
|
||||
idx = []
|
||||
for n in names:
|
||||
nn = _norm_name(n)
|
||||
if nn in name_to_idx:
|
||||
idx.append(name_to_idx[nn])
|
||||
return sorted(list(set(idx)))
|
||||
|
||||
|
||||
def _fallback_region_indices(n_channels: int):
|
||||
a = int(n_channels * 0.33)
|
||||
b = int(n_channels * 0.66)
|
||||
frontal = list(range(0, a))
|
||||
central = list(range(a, b))
|
||||
parietal = list(range(b, n_channels))
|
||||
prefrontal = list(range(0, max(2, a // 2)))
|
||||
posterior = list(range(b, n_channels))
|
||||
left = [i for i in range(n_channels) if i % 2 == 0]
|
||||
right = [i for i in range(n_channels) if i % 2 == 1]
|
||||
return frontal, central, parietal, prefrontal, posterior, left, right
|
||||
|
||||
|
||||
def get_region_indices(name_to_idx, n_channels: int):
|
||||
if not name_to_idx:
|
||||
return _fallback_region_indices(n_channels)
|
||||
|
||||
central_names = ["CZ","C1","C2","C3","C4","C5","C6","CP1","CP2","CP3","CP4","CP5","CP6","FC1","FC2","FC3","FC4","FC5","FC6"]
|
||||
frontal_names = ["FZ","F1","F2","F3","F4","F5","F6","F7","F8","AF3","AF4","AF7","AF8","FPZ","FP1","FP2","FCZ"]
|
||||
parietal_names = ["PZ","P1","P2","P3","P4","P5","P6","POZ","PO3","PO4","PO5","PO6","PO7","PO8","CPZ"]
|
||||
prefrontal_names = ["FP1","FP2","FPZ","AF3","AF4","AF7","AF8"]
|
||||
posterior_names = ["O1","O2","OZ","PO7","PO8","PO3","PO4","PZ","P3","P4","P1","P2"]
|
||||
|
||||
central = pick_indices_by_names(name_to_idx, central_names)
|
||||
frontal = pick_indices_by_names(name_to_idx, frontal_names)
|
||||
parietal = pick_indices_by_names(name_to_idx, parietal_names)
|
||||
prefrontal = pick_indices_by_names(name_to_idx, prefrontal_names)
|
||||
posterior = pick_indices_by_names(name_to_idx, posterior_names)
|
||||
|
||||
left_names = ["FP1","AF3","AF7","F3","F5","F7"]
|
||||
right_names = ["FP2","AF4","AF8","F4","F6","F8"]
|
||||
left = pick_indices_by_names(name_to_idx, left_names)
|
||||
right = pick_indices_by_names(name_to_idx, right_names)
|
||||
|
||||
if not (central and frontal and parietal and prefrontal and posterior):
|
||||
fb = _fallback_region_indices(n_channels)
|
||||
frontal2, central2, parietal2, prefrontal2, posterior2, left2, right2 = fb
|
||||
frontal = frontal if frontal else frontal2
|
||||
central = central if central else central2
|
||||
parietal = parietal if parietal else parietal2
|
||||
prefrontal = prefrontal if prefrontal else prefrontal2
|
||||
posterior = posterior if posterior else posterior2
|
||||
left = left if left else left2
|
||||
right = right if right else right2
|
||||
|
||||
return frontal, central, parietal, prefrontal, posterior, left, right
|
||||
|
||||
|
||||
# ==========================
|
||||
# Welch PSD + band power
|
||||
# ==========================
|
||||
def welch_psd(eeg_tc: np.ndarray, fs: float):
|
||||
nperseg = min(1024, eeg_tc.shape[0])
|
||||
if nperseg < 128:
|
||||
nperseg = min(256, eeg_tc.shape[0])
|
||||
freqs, pxx = signal.welch(
|
||||
eeg_tc, fs=fs, nperseg=nperseg, noverlap=nperseg // 2,
|
||||
axis=0, scaling="density",
|
||||
)
|
||||
return freqs, pxx
|
||||
|
||||
|
||||
def band_power_from_psd(freqs, pxx_fc, band):
|
||||
lo, hi = band
|
||||
m = (freqs >= lo) & (freqs < hi)
|
||||
if not np.any(m):
|
||||
return np.zeros((pxx_fc.shape[1],), dtype=np.float32)
|
||||
|
||||
# 兼容处理:numpy 2.0+ 推荐使用 trapezoid,旧版本用 trapz
|
||||
if hasattr(np, "trapezoid"):
|
||||
return np.trapezoid(pxx_fc[m, :], freqs[m], axis=0).astype(np.float32)
|
||||
else:
|
||||
return np.trapz(pxx_fc[m, :], freqs[m], axis=0).astype(np.float32)
|
||||
|
||||
|
||||
def region_mean_power(freqs, pxx_fc, idx, band) -> float:
|
||||
if not idx:
|
||||
return 0.0
|
||||
pw = band_power_from_psd(freqs, pxx_fc, band)
|
||||
return float(np.mean(pw[idx]))
|
||||
|
||||
|
||||
def compute_iaf(freqs, pxx_fc, posterior_idx):
|
||||
lo, hi = BANDS_METRICS["Alpha"]
|
||||
m = (freqs >= lo) & (freqs <= hi)
|
||||
if not np.any(m) or not posterior_idx:
|
||||
return 0.0
|
||||
spec = np.mean(pxx_fc[:, posterior_idx], axis=1)
|
||||
sub = spec[m]
|
||||
fsub = freqs[m]
|
||||
return float(fsub[int(np.argmax(sub))])
|
||||
|
||||
|
||||
# ==========================
|
||||
# 图:EEG波形、PSD
|
||||
# ==========================
|
||||
def plot_eeg_waveforms(data_uv_tc: np.ndarray, fs: float, ch_names, out_dir: str, seconds: int = 10):
|
||||
"""
|
||||
固定用 FIXED_EEG_IDXS 画 EEG.png(按重要性排序)
|
||||
data_uv_tc: (T, C) μV
|
||||
"""
|
||||
T, C = data_uv_tc.shape
|
||||
|
||||
# 1) 过滤越界索引(避免你的数据通道数不足时报错)
|
||||
idxs = [i for i in FIXED_EEG_IDXS if 0 <= i < C]
|
||||
if len(idxs) < len(FIXED_EEG_IDXS):
|
||||
missing = [i for i in FIXED_EEG_IDXS if not (0 <= i < C)]
|
||||
print(f"[WARN] Some fixed EEG indices out of range (C={C}): {missing}")
|
||||
|
||||
if len(idxs) == 0:
|
||||
raise RuntimeError(f"No valid indices in FIXED_EEG_IDXS for current data (C={C}).")
|
||||
|
||||
# 2) 通道显示
|
||||
picked_names = []
|
||||
for idx in idxs:
|
||||
# 找 idx 在 FIXED_EEG_IDXS 的位置,用对应标签
|
||||
pos = FIXED_EEG_IDXS.index(idx)
|
||||
std_label = FIXED_EEG_LABELS[pos] if pos < len(FIXED_EEG_LABELS) else f"CH{idx}"
|
||||
if ch_names and idx < len(ch_names):
|
||||
picked_names.append(f"{std_label}")
|
||||
else:
|
||||
picked_names.append(std_label)
|
||||
|
||||
# 3) 截取前 seconds 秒
|
||||
max_samples = int(min(T, seconds * fs))
|
||||
x = np.arange(max_samples) / fs
|
||||
|
||||
fig_h = 1.4 * len(idxs) + 1
|
||||
fig, axes = plt.subplots(len(idxs), 1, figsize=(10, fig_h), sharex=True)
|
||||
if len(idxs) == 1:
|
||||
axes = [axes]
|
||||
|
||||
# 4) 分位数定范围,避免尖峰撑爆
|
||||
seg = data_uv_tc[:max_samples, idxs].T # (n_ch, samples)
|
||||
lo = float(np.percentile(seg, 1))
|
||||
hi = float(np.percentile(seg, 99))
|
||||
m = max(abs(lo), abs(hi))
|
||||
m = max(m, 50.0)
|
||||
|
||||
for ax, ch_idx, nm in zip(axes, idxs, picked_names):
|
||||
y = data_uv_tc[:max_samples, ch_idx]
|
||||
ax.plot(x, y, linewidth=1.2)
|
||||
ax.set_ylabel("μV")
|
||||
ax.set_title(nm, loc="left", fontsize=10)
|
||||
ax.grid(True, alpha=0.3)
|
||||
ax.set_ylim(-m, m)
|
||||
|
||||
axes[-1].set_xlabel("Time (s)")
|
||||
plt.tight_layout()
|
||||
|
||||
out_path = os.path.join(out_dir, "EEG.png")
|
||||
plt.savefig(out_path, dpi=200)
|
||||
plt.close(fig)
|
||||
print(f"[OK] EEG waveform saved: {out_path}")
|
||||
|
||||
|
||||
|
||||
def plot_psd(eeg_uV_tc, fs, ch_names, out_dir):
|
||||
C = eeg_uV_tc.shape[1]
|
||||
chosen_idx = []
|
||||
|
||||
if ch_names:
|
||||
mp = {n.upper(): i for i, n in enumerate(ch_names)}
|
||||
for p in ["C3","C4","CZ"]:
|
||||
if p in mp:
|
||||
chosen_idx.append(mp[p])
|
||||
if len(chosen_idx) < 3:
|
||||
stds = [(i, float(np.std(eeg_uV_tc[:, i]))) for i in range(C)]
|
||||
stds.sort(key=lambda x: x[1], reverse=True)
|
||||
for i, _ in stds:
|
||||
if i not in chosen_idx:
|
||||
chosen_idx.append(i)
|
||||
if len(chosen_idx) == 3:
|
||||
break
|
||||
chosen_name = [ch_names[i] for i in chosen_idx]
|
||||
else:
|
||||
stds = [(i, float(np.std(eeg_uV_tc[:, i]))) for i in range(C)]
|
||||
stds.sort(key=lambda x: x[1], reverse=True)
|
||||
chosen_idx = [i for i, _ in stds[:3]]
|
||||
chosen_name = [f"CH{i}" for i in chosen_idx]
|
||||
|
||||
fig = plt.figure(figsize=(7.5, 4.8))
|
||||
for idx, nm in zip(chosen_idx, chosen_name):
|
||||
f, pxx = signal.welch(eeg_uV_tc[:, idx], fs=fs, nperseg=int(2*fs), noverlap=int(1*fs))
|
||||
mask = (f >= PSD_FMIN) & (f <= PSD_FMAX)
|
||||
p_db = 10 * np.log10(pxx[mask] + 1e-20)
|
||||
plt.plot(f[mask], p_db, linewidth=1.8, label=nm)
|
||||
|
||||
plt.xlabel("Hz")
|
||||
plt.ylabel("Power (dB)")
|
||||
plt.title("PSD")
|
||||
plt.grid(True, alpha=0.3)
|
||||
plt.legend()
|
||||
plt.tight_layout()
|
||||
out_path = os.path.join(out_dir, "psd.png")
|
||||
plt.savefig(out_path, dpi=200)
|
||||
plt.close(fig)
|
||||
print(f"[OK] psd.png -> {out_path}")
|
||||
|
||||
|
||||
# ==========================
|
||||
# Topomap(如果有 xyz)
|
||||
# ==========================
|
||||
def build_mne_raw_from_uV(eeg_uV_tc, fs, ch_names, xyz):
|
||||
C = eeg_uV_tc.shape[1]
|
||||
if not ch_names:
|
||||
ch_names = [f"CH{i+1}" for i in range(C)]
|
||||
data_v_ct = eeg_uV_tc.T * 1e-6 # (C,T) V
|
||||
info = mne.create_info(ch_names=ch_names, sfreq=fs, ch_types=["eeg"] * C)
|
||||
raw = mne.io.RawArray(data_v_ct, info, verbose=False)
|
||||
|
||||
if xyz is not None and isinstance(xyz, np.ndarray) and xyz.shape == (C, 3):
|
||||
try:
|
||||
ch_pos = {ch_names[i]: xyz[i, :] for i in range(C)}
|
||||
montage = mne.channels.make_dig_montage(ch_pos=ch_pos, coord_frame="head")
|
||||
raw.set_montage(montage, on_missing="ignore")
|
||||
except Exception as e:
|
||||
print(f"[WARN] set_montage failed (ignore): {e}")
|
||||
else:
|
||||
print("[WARN] electrode_xyz missing/invalid -> skip topomap")
|
||||
|
||||
return raw
|
||||
|
||||
|
||||
def _raw_has_positions(raw):
|
||||
try:
|
||||
locs = np.array([ch["loc"][:3] for ch in raw.info["chs"]])
|
||||
ok = np.isfinite(locs).all() and (np.linalg.norm(locs, axis=1) > 0).any()
|
||||
return bool(ok)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def compute_band_powers_for_topomap(raw, bands):
|
||||
data = raw.get_data() # (C,T) V
|
||||
fs = raw.info["sfreq"]
|
||||
psds, freqs = mne.time_frequency.psd_array_welch(
|
||||
data, sfreq=fs,
|
||||
fmin=min(v[0] for v in bands.values()),
|
||||
fmax=max(v[1] for v in bands.values()),
|
||||
n_fft=int(2 * fs),
|
||||
n_overlap=int(1 * fs),
|
||||
average="mean",
|
||||
verbose=False
|
||||
)
|
||||
out = {}
|
||||
for k, (fmin, fmax) in bands.items():
|
||||
idx = np.where((freqs >= fmin) & (freqs <= fmax))[0]
|
||||
|
||||
# 兼容处理:numpy 2.0+ 推荐使用 trapezoid,旧版本用 trapz
|
||||
if hasattr(np, "trapezoid"):
|
||||
bp = np.trapezoid(psds[:, idx], freqs[idx], axis=1) # (C,)
|
||||
else:
|
||||
bp = np.trapz(psds[:, idx], freqs[idx], axis=1) # (C,)
|
||||
|
||||
v = np.log10(bp + 1e-30)
|
||||
v = v - np.mean(v)
|
||||
out[k] = v
|
||||
return out
|
||||
|
||||
|
||||
def plot_average_topomap(raw, values, out_dir):
|
||||
fig, ax = plt.subplots(1, 1, figsize=(6.5, 4.6))
|
||||
im, _ = mne.viz.plot_topomap(values, raw.info, axes=ax, show=False, contours=0,sphere=(0, 0, 0, 0.11))
|
||||
ax.set_title("0.8-30 Hz", fontsize=12)
|
||||
plt.colorbar(im, ax=ax, shrink=0.85)
|
||||
plt.tight_layout()
|
||||
out_path = os.path.join(out_dir, "average_topomap.png")
|
||||
plt.savefig(out_path, dpi=200)
|
||||
plt.close(fig)
|
||||
print(f"[OK] average_topomap.png -> {out_path}")
|
||||
|
||||
|
||||
def plot_band_topomaps(raw, band_values, out_dir):
|
||||
order = [
|
||||
("delta", "δ (0.8-3.9Hz)"),
|
||||
("theta", "θ (4-7.9Hz)"),
|
||||
("alpha", "α (8-12.9Hz)"),
|
||||
("beta", "β (13-30Hz)"),
|
||||
("broad", "0.8-30 Hz"),
|
||||
]
|
||||
fig, axes = plt.subplots(1, 5, figsize=(16, 4.2))
|
||||
ims = []
|
||||
for ax, (k, title) in zip(axes, order):
|
||||
im, _ = mne.viz.plot_topomap(band_values[k], raw.info, axes=ax, show=False, contours=0,extrapolate='head',sphere=(0, 0, 0, 0.11))
|
||||
ax.set_title(title, fontsize=11)
|
||||
ims.append(im)
|
||||
fig.subplots_adjust(left=0.02, right=0.85, top=0.88, bottom=0.05, wspace=0.35)
|
||||
cax = fig.add_axes([0.87, 0.15, 0.015, 0.7])
|
||||
fig.colorbar(ims[-1], cax=cax)
|
||||
out_path = os.path.join(out_dir, "topomaps.png")
|
||||
plt.savefig(out_path, dpi=200)
|
||||
plt.close(fig)
|
||||
print(f"[OK] topomaps.png -> {out_path}")
|
||||
|
||||
# ==========================
|
||||
# 生成 ResultData.txt
|
||||
# ==========================
|
||||
def compute_and_save_txt(model_path, mat_dir, out_dir, eeg_uV_tc, fs, ch_names):
|
||||
pred_label = _predict_label_by_model(model_path, mat_dir)
|
||||
recommend = "是" if pred_label == "MDD" else "否"
|
||||
|
||||
T, C = eeg_uV_tc.shape
|
||||
mp = build_channel_index_map(ch_names, C)
|
||||
frontal_idx, central_idx, parietal_idx, prefrontal_idx, posterior_idx, left_idx, right_idx = \
|
||||
get_region_indices(mp, C)
|
||||
|
||||
freqs, pxx = welch_psd(eeg_uV_tc, fs)
|
||||
|
||||
central_alpha = region_mean_power(freqs, pxx, central_idx, BANDS_METRICS["Alpha"])
|
||||
central_beta = region_mean_power(freqs, pxx, central_idx, BANDS_METRICS["Beta"])
|
||||
frontal_alpha = region_mean_power(freqs, pxx, frontal_idx, BANDS_METRICS["Alpha"])
|
||||
frontal_beta = region_mean_power(freqs, pxx, frontal_idx, BANDS_METRICS["Beta"])
|
||||
par_alpha = region_mean_power(freqs, pxx, parietal_idx, BANDS_METRICS["Alpha"])
|
||||
par_beta = region_mean_power(freqs, pxx, parietal_idx, BANDS_METRICS["Beta"])
|
||||
|
||||
central_ab = (central_alpha / (central_beta + EPS)) if central_beta > 0 else 0.0
|
||||
frontal_ab = (frontal_alpha / (frontal_beta + EPS)) if frontal_beta > 0 else 0.0
|
||||
par_ab = (par_alpha / (par_beta + EPS)) if par_beta > 0 else 0.0
|
||||
|
||||
central_theta = region_mean_power(freqs, pxx, central_idx, BANDS_METRICS["Theta"])
|
||||
par_theta = region_mean_power(freqs, pxx, parietal_idx, BANDS_METRICS["Theta"])
|
||||
central_tb = (central_theta / (central_beta + EPS)) if central_beta > 0 else 0.0
|
||||
par_tb = (par_theta / (par_beta + EPS)) if par_beta > 0 else 0.0
|
||||
|
||||
if not left_idx or not right_idx:
|
||||
left_idx = [i for i in prefrontal_idx if (i % 2 == 0)]
|
||||
right_idx = [i for i in prefrontal_idx if (i % 2 == 1)]
|
||||
left_alpha = region_mean_power(freqs, pxx, left_idx, BANDS_METRICS["Alpha"])
|
||||
right_alpha = region_mean_power(freqs, pxx, right_idx, BANDS_METRICS["Alpha"])
|
||||
prefrontal_alpha_asym = float(np.log(right_alpha + EPS) - np.log(left_alpha + EPS))
|
||||
|
||||
iaf = compute_iaf(freqs, pxx, posterior_idx)
|
||||
|
||||
pre_td = region_mean_power(freqs, pxx, prefrontal_idx, (BANDS_METRICS["Delta"][0], BANDS_METRICS["Theta"][1]))
|
||||
pre_total = region_mean_power(freqs, pxx, prefrontal_idx, TOTAL_POWER_BAND)
|
||||
pre_td_rel = (pre_td / (pre_total + EPS)) * 100.0 if pre_total > 0 else 0.0
|
||||
|
||||
def f1(x): return f"{x:.1f}"
|
||||
|
||||
txt = (
|
||||
f"中央区α/β波比值:{f1(central_ab)}\n"
|
||||
f"额区α/β波比值:{f1(frontal_ab)}\n"
|
||||
f"顶区α/β波比值:{f1(par_ab)}\n"
|
||||
f"中央区θ/β波比值:{f1(central_tb)}\n"
|
||||
f"顶区θ/β波比值:{f1(par_tb)}\n"
|
||||
f"前额叶α波不对称性:{f1(prefrontal_alpha_asym)}\n"
|
||||
f"个体化α峰值频率:{f1(iaf)}\n"
|
||||
f"前额叶θ+δ波功率:{f1(pre_td_rel)}\n"
|
||||
f"是否推荐治疗:{recommend}\n"
|
||||
)
|
||||
|
||||
out_path = os.path.join(out_dir, "ResultData.txt")
|
||||
with open(out_path, "w", encoding="utf-8") as f:
|
||||
f.write(txt)
|
||||
print(f"[OK] ResultData.txt -> {out_path}")
|
||||
|
||||
|
||||
# ==========================
|
||||
# 一个函数:一次性跑完(txt + 图片)
|
||||
# ==========================
|
||||
def run_all(model_path: str, mat_dir: str, out_root: str, seconds: int = EEG_PLOT_SECONDS):
|
||||
# 1) 选第一个 mat
|
||||
if not os.path.exists(mat_dir):
|
||||
raise RuntimeError(f"输入目录不存在: {mat_dir}")
|
||||
|
||||
mats = [f for f in os.listdir(mat_dir) if f.lower().endswith(".mat")]
|
||||
if not mats:
|
||||
raise RuntimeError(f"mat_dir 下找不到 .mat: {mat_dir}")
|
||||
mats.sort()
|
||||
mat_file = os.path.join(mat_dir, mats[0])
|
||||
print(f"[INFO] Found mat: {mat_file}")
|
||||
|
||||
# 2) 创建输出目录
|
||||
out_dir = ensure_outdir(out_root)
|
||||
print(f"[INFO] Output dir: {out_dir}")
|
||||
|
||||
# --- 总是进行预处理 (默认模式) ---
|
||||
print("[INFO] Mode: Raw Data (Default). Running preprocessing...")
|
||||
temp_dir = os.path.join(out_dir, "temp_preprocessed")
|
||||
mat_file = preprocess_mat_file(mat_file, temp_dir)
|
||||
# 更新 mat_dir 指向临时目录(为了传给 compute_and_save_txt 里的 predict 接口)
|
||||
mat_dir = temp_dir
|
||||
|
||||
# 3) 读 EEG(μV)
|
||||
eeg_uV_tc, fs, ch_names, xyz = load_eeg_from_mat(mat_file)
|
||||
print(f"[INFO] eeg shape(T,C)={eeg_uV_tc.shape}, fs={fs}")
|
||||
|
||||
# 5) 画图:PSD + EEG
|
||||
plot_psd(eeg_uV_tc, fs, ch_names, out_dir)
|
||||
plot_eeg_waveforms(eeg_uV_tc, fs, ch_names, out_dir, seconds=seconds)
|
||||
|
||||
# 6) topomap(有 xyz 才画)
|
||||
try:
|
||||
raw = build_mne_raw_from_uV(eeg_uV_tc, fs, ch_names, xyz)
|
||||
if _raw_has_positions(raw):
|
||||
band_vals = compute_band_powers_for_topomap(raw, BANDS_TOPOMAP)
|
||||
plot_average_topomap(raw, band_vals["broad"], out_dir)
|
||||
plot_band_topomaps(raw, band_vals, out_dir)
|
||||
else:
|
||||
print("[WARN] No valid positions -> skip topomap.")
|
||||
except Exception as e:
|
||||
print(f"[WARN] topomap failed -> skip. reason: {e}")
|
||||
|
||||
# 4) 指标写 txt
|
||||
compute_and_save_txt(model_path, mat_dir, out_dir, eeg_uV_tc, fs, ch_names)
|
||||
|
||||
print("[DONE] txt + figures generated.")
|
||||
return out_dir
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import multiprocessing
|
||||
multiprocessing.freeze_support()
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
# 1. 路径锚定:获取资源绝对路径
|
||||
def get_resource_path(relative_path):
|
||||
"""
|
||||
获取资源的绝对路径。
|
||||
策略:优先在当前执行目录(EXE所在目录)寻找。
|
||||
这适用于“绿色软件”模式,即资源文件(model/raw_data)直接放在EXE旁边。
|
||||
"""
|
||||
if getattr(sys, 'frozen', False):
|
||||
# PyInstaller 打包后的 EXE 所在目录
|
||||
base_path = os.path.dirname(sys.executable)
|
||||
else:
|
||||
# 开发环境:当前脚本所在目录
|
||||
base_path = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
return os.path.join(base_path, relative_path)
|
||||
|
||||
# 设置默认路径
|
||||
DEFAULT_MODEL = get_resource_path(os.path.join("model", "Model_1.pth"))
|
||||
# 这里我们保持 mat_dir 和 out_root 相对于 EXE 所在目录(或当前工作目录)
|
||||
if getattr(sys, 'frozen', False):
|
||||
EXE_DIR = os.path.dirname(sys.executable)
|
||||
else:
|
||||
EXE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
DEFAULT_MAT = os.path.join(EXE_DIR, "raw_data")
|
||||
DEFAULT_OUT = os.path.join(EXE_DIR, "out")
|
||||
|
||||
# 2. 解析命令行参数
|
||||
parser = argparse.ArgumentParser(description="EEG Depression Assessment Algorithm Integration")
|
||||
parser.add_argument("--model_path", type=str, default=DEFAULT_MODEL, help="模型文件的路径 (.pth)")
|
||||
parser.add_argument("--mat_dir", type=str, default=DEFAULT_MAT, help="输入文件夹路径 (包含原始EEG .mat)")
|
||||
parser.add_argument("--out_root", type=str, default=DEFAULT_OUT, help="结果输出的根目录")
|
||||
parser.add_argument("--seconds", type=int, default=10, help="画波形图截取的秒数")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 3. 检查关键路径
|
||||
if not os.path.exists(args.mat_dir):
|
||||
print(f"[WARN] 输入文件夹不存在: {args.mat_dir}")
|
||||
if not os.path.exists(args.model_path):
|
||||
print(f"[WARN] 模型文件不存在: {args.model_path}")
|
||||
|
||||
# 4. 执行主流程
|
||||
print(f"[*] 运行配置:")
|
||||
print(f" - Model : {args.model_path}")
|
||||
print(f" - Input : {args.mat_dir}")
|
||||
print(f" - Output: {args.out_root}")
|
||||
print(f" - Mode : RAW (Auto Preprocess)")
|
||||
|
||||
run_all(args.model_path, args.mat_dir, args.out_root, seconds=args.seconds)
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user