Files
Depression_TMS/algorithm_V1/infer_pth.py

562 lines
18 KiB
Python
Raw Normal View History

2026-06-01 13:18:36 +08:00
# -*- coding: utf-8 -*-
"""
infer_pth.py
用途
- 从一个文件夹中自动读取第一个 .mat EEG 文件64通道或32通道
- 若为64通道则按 idx64_to_32 映射选出32通道
- 提取切片特征DE + PSD(var近似)不含Asym
- 加载你训练好的 .pth 模型FusionNet结构
- 输出该受试者的 HC / MDD 判断结果
运行方式命令行
python infer_pth_from64_to32.py --eeg_dir "D:\\xxx\\folder" --model_path "C:\\xxx\\model.pth"
也可在其他py里import
from infer_pth_from64_to32 import predict_hc_mdd
res = predict_hc_mdd(eeg_dir, model_path)
print(res)
"""
from __future__ import annotations
import os
import argparse
import numpy as np
import scipy.io
import scipy.signal as signal
import torch
import torch.nn as nn
import torch.nn.functional as F
# =========================================================
# 0) 配置区(按需改这里)
# =========================================================
# Sampling rate (must match the value used during training).
SAMPLING_RATE = 250

# Sliding-window parameters (must match training).
WINDOW_SIZE = 500
STRIDE = 250

# Frequency bands as (low, high) edges (must match training).
# BAND_NAMES fixes the band order used when features are laid out; it is
# derived from the dict so the two can never drift apart.
BANDS = {
    "Delta": (1, 4),
    "Theta": (4, 8),
    "Alpha": (8, 13),
    "Beta": (13, 30),
    "Gamma": (30, 50),
}
BAND_NAMES = list(BANDS)

# Whether to use the extended feature set (DE + PSD).
USE_EXTENDED_FEATURES = True

# Numerical-stability term added before taking logarithms.
EPS = 1e-12

# Compute device: CUDA when available, otherwise CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Channel mapping: (index into the 64-channel recording, electrode label).
# The labels are carried over from the original annotations; the code only
# uses the indices.
_CHANNEL_PICKS = (
    (23, "C5"),
    (47, "O1"),
    (39, "TP7"),
    (6, "FPZ"),
    (2, "PO6"),
    (21, "P4"),
    (35, "AF7"),
    (57, "AF3"),
    (1, "FP2"),
    (37, "T7"),
    (63, "F1"),
    (36, "A1"),
    (18, "FC4"),
    (31, "FC5"),
    (14, "FC2"),
    (48, "T8"),
    (60, "P2"),
    (41, "AF8"),
    (11, "CP1"),
    (0, "FP1"),
    (55, "PO7"),
    (59, "C1"),
    (22, "F5"),
    (10, "CP2"),
    (16, "C3"),
    (61, "P1"),
    (27, "CP5"),
    (17, "C4"),
    (26, "CP6"),
    (62, "F2"),
    (3, "POZ"),
    (13, "PO5"),
)
IDX64_TO_32 = [index for index, _label in _CHANNEL_PICKS]

# Subject-level decision threshold. A "subject_threshold" stored in the model
# checkpoint takes priority; this value is the fallback.
DEFAULT_SUBJECT_THRESHOLD = 0.5
# =========================================================
# 1) 模型结构
# =========================================================
class SEBlock(nn.Module):
    """Feature-wise gating layer.

    Two linear layers (channels -> channels // reduction -> channels) produce
    a sigmoid weight per input feature, and the input is rescaled by it.
    """

    def __init__(self, channels: int, reduction: int = 4) -> None:
        super().__init__()
        # The bottleneck width never drops below one unit.
        bottleneck = channels // reduction
        if bottleneck < 1:
            bottleneck = 1
        self.fc1 = nn.Linear(channels, bottleneck)
        self.fc2 = nn.Linear(bottleneck, channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` multiplied element-wise by its learned gate."""
        gate = torch.sigmoid(self.fc2(F.relu(self.fc1(x))))
        return x * gate
class ResidualBlock(nn.Module):
    """Two-layer fully connected block with a skip connection.

    Main path: Linear -> BatchNorm -> ReLU -> Dropout -> Linear -> BatchNorm.
    The skip path is the identity when the widths match, otherwise a
    Linear + BatchNorm projection to ``out_features``.
    """

    def __init__(self, in_features: int, out_features: int, dropout: float = 0.3) -> None:
        super().__init__()
        self.fc1 = nn.Linear(in_features, out_features)
        self.bn1 = nn.BatchNorm1d(out_features)
        self.fc2 = nn.Linear(out_features, out_features)
        self.bn2 = nn.BatchNorm1d(out_features)
        self.dropout = nn.Dropout(dropout)
        if in_features == out_features:
            self.shortcut = nn.Identity()
        else:
            # Project the input so it can be added to the main path.
            self.shortcut = nn.Sequential(
                nn.Linear(in_features, out_features),
                nn.BatchNorm1d(out_features),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ReLU(main_path(x) + shortcut(x))."""
        skip = self.shortcut(x)
        hidden = self.dropout(F.relu(self.bn1(self.fc1(x))))
        hidden = self.bn2(self.fc2(hidden))
        return F.relu(hidden + skip)
class FusionNet(nn.Module):
    """Fully connected network with a classification and a regression head.

    Pipeline: input BatchNorm -> three ResidualBlocks (512, 256, 128) ->
    SEBlock gate -> 64-unit projection -> ``cls_head`` (``num_classes``
    logits) and ``reg_head`` (``num_scales`` outputs).
    """

    def __init__(self, num_classes: int = 2, num_eeg_features: int = 320, num_scales: int = 6) -> None:
        super().__init__()
        self.input_norm = nn.BatchNorm1d(num_eeg_features)

        # (input width, output width, dropout) for block1..block3; registered
        # under the same attribute names so checkpoint keys are unchanged.
        stage_specs = (
            (num_eeg_features, 512, 0.4),
            (512, 256, 0.3),
            (256, 128, 0.2),
        )
        for number, (width_in, width_out, drop) in enumerate(stage_specs, start=1):
            setattr(self, f"block{number}", ResidualBlock(width_in, width_out, dropout=drop))

        self.attention = SEBlock(128, reduction=4)
        self.final_fc = nn.Sequential(
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
        self.cls_head = nn.Linear(64, num_classes)
        # The regression head exists so trained checkpoints load; inference
        # in this file only consumes the classification output.
        self.reg_head = nn.Sequential(
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(32, num_scales),
        )
        self._init_weights()

    def _init_weights(self) -> None:
        """Kaiming-normal init for Linear layers; unit scale / zero shift for BatchNorm1d."""
        for module in self.modules():
            if isinstance(module, nn.BatchNorm1d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                continue
            if not isinstance(module, nn.Linear):
                continue
            nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self, x: torch.Tensor):
        """Return ``(cls_out, reg_out)`` computed from the shared 64-unit features."""
        hidden = self.input_norm(x)
        for stage in (self.block1, self.block2, self.block3, self.attention, self.final_fc):
            hidden = stage(hidden)
        cls_out = self.cls_head(hidden)
        reg_out = self.reg_head(hidden)
        return cls_out, reg_out
# =========================================================
# 2) mat读取 + 通道裁剪
# =========================================================
def _find_first_mat_file(folder: str) -> str:
    """Return the path of the first ``.mat`` file in ``folder``.

    Candidates are the regular files directly inside ``folder`` whose name
    ends with ``.mat`` (case-insensitive); the first one in sorted name order
    is returned. Directories that happen to be named ``*.mat`` are ignored,
    since they cannot be loaded as EEG data.

    Raises:
        RuntimeError: if ``folder`` is not a directory or holds no .mat file.
    """
    if not os.path.isdir(folder):
        raise RuntimeError(f"eeg_dir 不是文件夹: {folder}")
    mats = sorted(
        name
        for name in os.listdir(folder)
        if name.lower().endswith(".mat") and os.path.isfile(os.path.join(folder, name))
    )
    if not mats:
        raise RuntimeError(f"文件夹内没有 .mat 文件: {folder}")
    return os.path.join(folder, mats[0])
import numpy as np
import scipy.io
def _unwrap_singleton(x):
    """Strip single-element ndarray wrappers from ``x`` until none remain.

    Shapes such as (1, 1) or (1,) are peeled repeatedly. A one-element object
    array is replaced by the object it holds; a one-element non-object array
    with at least one dimension is replaced by its only element. Anything
    else (non-arrays, arrays with more than one element, 0-d non-object
    arrays) is returned unchanged.
    """
    while isinstance(x, np.ndarray):
        if x.size != 1:
            break
        if x.dtype == object:
            x = x.item()
            continue
        if x.ndim < 1:
            break
        try:
            x = x.reshape(-1)[0]
        except Exception:
            # Best effort: keep whatever we have if the element cannot be taken.
            break
    return x
def _try_get_struct_field(v, field_name="data"):
    """Fetch ``field_name`` from a MATLAB-struct-like value, or return None.

    Two representations are recognised:
      1. an object exposing ``_fieldnames`` (scipy's mat_struct, as produced
         by ``loadmat(..., struct_as_record=False)``): the attribute is read;
      2. a NumPy structured/record array whose ``dtype.names`` contains the
         field: ``v[field_name]`` is returned (it may still need unwrapping).
    """
    # Case 1: attribute-style struct.
    if hasattr(v, "_fieldnames"):
        if field_name in v._fieldnames:
            return getattr(v, field_name)

    # Case 2: structured array.
    is_record = isinstance(v, np.ndarray) and v.dtype.names
    if not is_record or field_name not in v.dtype.names:
        return None
    try:
        return v[field_name]
    except Exception:
        return None
def load_eeg_from_mat_any_channels(mat_path: str) -> np.ndarray:
    """Load an EEG matrix from a .mat file and return it as float32.

    Variables in the file are scanned for:
      - a plain 2-D numeric matrix, in either (T, C) or (C, T) layout;
      - a struct with a ``data`` field holding a 2-D matrix;
      - an object array wrapping either of the above.
    Among all matches the most EEG-like one is kept (see ``eeg_likeness``),
    and it is transposed to (T, C) when the first axis looks like the channel
    axis.

    Raises:
        RuntimeError: if no candidate is found or the selected value is not 2-D.
    """
    known_channel_counts = (32, 64, 128, 129)

    def is_numeric_2d(obj) -> bool:
        return (
            isinstance(obj, np.ndarray)
            and obj.ndim == 2
            and np.issubdtype(obj.dtype, np.number)
        )

    def eeg_likeness(arr: np.ndarray) -> int:
        # Prefer arrays with a familiar channel count along either axis, then
        # larger arrays (one point per 100000 elements).
        points = 0
        for count, bonus in ((64, 10), (32, 9), (128, 8), (129, 7)):
            if count in arr.shape:
                points += bonus
        return points + int(np.prod(arr.shape) // 100000)

    # struct_as_record=False exposes struct fields as attributes, and
    # squeeze_me=True drops singleton dimensions.
    mat = scipy.io.loadmat(mat_path, struct_as_record=False, squeeze_me=True)

    candidates = []
    for name, value in mat.items():
        if name.startswith("__"):
            continue  # loadmat metadata entries

        # 1) Plain 2-D numeric matrix.
        if is_numeric_2d(value):
            candidates.append((name, value))
            continue

        # 2) Struct / record: take its "data" field.
        field = _try_get_struct_field(value, "data")
        if field is not None:
            field = _unwrap_singleton(field)
            if isinstance(field, np.ndarray) and field.ndim == 2:
                if np.issubdtype(field.dtype, np.number) or field.dtype == object:
                    candidates.append((f"{name}.data", field))
                    continue

        # 3) Object array: unwrap, then retry both interpretations.
        if isinstance(value, np.ndarray) and value.dtype == object:
            inner = _unwrap_singleton(value)
            if is_numeric_2d(inner):
                candidates.append((name, inner))
                continue
            nested = _try_get_struct_field(inner, "data")
            if nested is not None:
                nested = _unwrap_singleton(nested)
                if isinstance(nested, np.ndarray) and nested.ndim == 2:
                    candidates.append((f"{name}.data", nested))

    if not candidates:
        raise RuntimeError(f"mat里没找到可用EEG二维矩阵或struct.data{mat_path}")

    # Highest score wins; on ties the earliest candidate is kept.
    key, eeg = max(candidates, key=lambda item: eeg_likeness(item[1]))
    eeg = _unwrap_singleton(eeg)

    # Object arrays sometimes hold plain numbers; force a float32 copy of those.
    needs_copy = isinstance(eeg, np.ndarray) and eeg.dtype == object
    convert = np.array if needs_copy else np.asarray
    eeg = convert(eeg, dtype=np.float32)

    if eeg.ndim != 2:
        raise RuntimeError(f"解析结果不是二维矩阵: key={key}, shape={eeg.shape}, file={mat_path}")

    # Normalise to (T, C): transpose when the first axis has a known channel
    # count and the second axis either does not, or is the larger of the two.
    rows, cols = eeg.shape
    if rows in known_channel_counts and (cols not in known_channel_counts or rows < cols):
        eeg = eeg.T
    return eeg
def ensure_32_channels(eeg: np.ndarray) -> np.ndarray:
    """Reduce a (T, C) recording to (T, 32).

    - C == 32: the input is returned as-is.
    - C == 64: the columns listed in IDX64_TO_32 are selected, in that order.

    Raises:
        RuntimeError: if ``eeg`` is not 2-D, if C is neither 32 nor 64, or if
            IDX64_TO_32 holds an index outside 0..63.
    """
    if eeg.ndim != 2:
        raise RuntimeError(f"EEG必须是二维(T,C),但得到: {eeg.shape}")

    n_channels = eeg.shape[1]
    if n_channels == 32:
        return eeg
    if n_channels != 64:
        raise RuntimeError(f"不支持的通道数C={n_channels},当前只支持 64->32 或 32 直推。")

    picks = np.asarray(IDX64_TO_32, dtype=np.int64)
    lowest, highest = picks.min(), picks.max()
    if lowest < 0 or highest >= 64:
        raise RuntimeError(f"IDX64_TO_32 越界min={lowest}, max={highest} (要求0~63)")
    return eeg[:, picks]
# =========================================================
# 3) 特征提取DE + PSD(var近似)
# =========================================================
class FeatureExtractor32:
    """Sliding-window band features for 32-channel EEG.

    The recording is band-pass filtered once per band in BAND_NAMES and then
    cut into windows. For each window, channel and band the extractor
    computes a differential-entropy (DE) value from the sample variance and,
    when USE_EXTENDED_FEATURES is true, a log-variance value used as a PSD
    approximation.

    Feature width per slice:
      - USE_EXTENDED_FEATURES=True:  DE (32*5) + PSD (32*5) = 320
      - otherwise:                   DE (32*5)              = 160

    Args:
        fs: sampling rate passed to the filter design.
        window_size: window length in samples.
        stride: hop between consecutive window starts, in samples.
        filter_order: order passed to ``scipy.signal.butter``.
        zero_phase: use ``sosfiltfilt`` when true, ``sosfilt`` otherwise.
    """

    def __init__(
        self,
        fs: int = SAMPLING_RATE,
        window_size: int = WINDOW_SIZE,
        stride: int = STRIDE,
        filter_order: int = 4,
        zero_phase: bool = False,
    ) -> None:
        self.fs = fs
        self.window_size = window_size
        self.stride = stride
        self.filter_order = filter_order
        self.zero_phase = zero_phase
        # One band-pass filter per band, designed once and reused.
        self._sos = {}
        for bn in BAND_NAMES:
            low, high = BANDS[bn]
            self._sos[bn] = signal.butter(
                self.filter_order, [low, high],
                btype="band", fs=self.fs, output="sos"
            )

    def _filter_bands(self, eeg: np.ndarray) -> dict[str, np.ndarray]:
        """Return ``{band name: float32 band-filtered copy of eeg}`` (filtered along axis 0)."""
        apply_filter = signal.sosfiltfilt if self.zero_phase else signal.sosfilt
        return {
            bn: apply_filter(self._sos[bn], eeg, axis=0).astype(np.float32)
            for bn in BAND_NAMES
        }

    def extract(self, eeg32: np.ndarray) -> np.ndarray:
        """Compute per-slice features.

        Args:
            eeg32: array of shape (T, 32).

        Returns:
            float32 array of shape (N_slices, feat_dim). Within each half
            (DE, then PSD) values are ordered channel-major: all bands of
            channel 0, then all bands of channel 1, and so on.

        Raises:
            RuntimeError: if ``eeg32`` is not (T, 32) or T is shorter than
                one window.
        """
        if eeg32.ndim != 2 or eeg32.shape[1] != 32:
            raise RuntimeError(f"extract需要 (T,32),得到 {eeg32.shape}")
        bands_data = self._filter_bands(eeg32)
        T = eeg32.shape[0]
        feats = []
        # "+ 1" makes the last admissible start inclusive: a window ending
        # exactly at T is a full window. Without it that window was dropped,
        # and a recording of exactly window_size samples was rejected as too
        # short. NOTE(review): if the training pipeline used the exclusive
        # bound, inference may now see one extra trailing slice per recording.
        for start in range(0, T - self.window_size + 1, self.stride):
            end = start + self.window_size
            de_list = []
            psd_list = []
            for bn in BAND_NAMES:
                seg = bands_data[bn][start:end, :]  # (W, 32)
                var = np.var(seg, axis=0, ddof=1)  # (32,)
                # DE computed as 0.5 * log(2*pi*e*var).
                de_list.append(0.5 * np.log(2 * np.pi * np.e * (var + EPS)))
                if USE_EXTENDED_FEATURES:
                    # PSD approximation: log(var).
                    psd_list.append(np.log(var + EPS))
            # (5, 32) -> (32, 5) -> flat, i.e. channel-major ordering.
            de_feat = np.stack(de_list, axis=0).T.reshape(-1)  # (32*5,)
            if USE_EXTENDED_FEATURES:
                psd_feat = np.stack(psd_list, axis=0).T.reshape(-1)  # (32*5,)
                f = np.concatenate([de_feat, psd_feat], axis=0).astype(np.float32)
            else:
                f = de_feat.astype(np.float32)
            feats.append(f)
        if not feats:
            raise RuntimeError("EEG长度不足以切片请检查T是否太短或调整WINDOW_SIZE/STRIDE")
        return np.stack(feats, axis=0).astype(np.float32)
# =========================================================
# 4) 模型加载 + 推理接口
# =========================================================
def _safe_torch_load(path: str):
    """Load a checkpoint onto DEVICE with ``weights_only=False``.

    If ``torch.load`` rejects the ``weights_only`` keyword with a TypeError
    (presumably an older PyTorch release), the call is retried without it.

    NOTE(review): ``weights_only=False`` means full unpickling, which can run
    arbitrary code from the file; only load checkpoints from trusted sources.
    """
    common = {"map_location": DEVICE}
    try:
        return torch.load(path, weights_only=False, **common)
    except TypeError:
        return torch.load(path, **common)
def load_model(model_path: str) -> tuple[FusionNet, dict]:
    """Build a FusionNet, load trained weights into it and switch it to eval mode.

    Two checkpoint layouts are accepted:
      - a dict with a "model_state" entry (optionally "feat_dim", default
        320): the whole dict is returned as the checkpoint metadata;
      - anything else is treated as the state dict itself, with feat_dim 320
        and empty metadata.

    Returns:
        (model, ckpt_dict)
    """
    loaded = _safe_torch_load(model_path)

    wrapped = isinstance(loaded, dict) and "model_state" in loaded
    if wrapped:
        ckpt, state = loaded, loaded["model_state"]
        feat_dim = int(loaded.get("feat_dim", 320))
    else:
        ckpt, state, feat_dim = {}, loaded, 320

    model = FusionNet(num_classes=2, num_eeg_features=feat_dim)
    model = model.to(DEVICE)
    # strict=True: fail on any missing or unexpected parameter name.
    model.load_state_dict(state, strict=True)
    model.eval()
    return model, ckpt
def predict_hc_mdd(eeg_dir: str, model_path: str) -> dict:
    """Classify one subject as HC or MDD from the first .mat file in a folder.

    Args:
        eeg_dir: folder containing the EEG .mat file.
        model_path: path to the trained .pth checkpoint.

    Returns:
        dict with the fields
          - mat_file: the .mat file that was used
          - pred_label: "HC" or "MDD"
          - p_mdd_mean: mean of the per-slice p(MDD)
          - threshold: subject-level decision threshold
          - n_slices: number of slices

    Warns:
        RuntimeWarning: when the checkpoint carries global_mean/global_std
            whose length does not match the feature width, so normalisation
            is skipped.
    """
    import warnings  # local import keeps this function self-contained

    mat_file = _find_first_mat_file(eeg_dir)

    # 1) Read the EEG as (T, C) and reduce it to 32 channels.
    eeg = load_eeg_from_mat_any_channels(mat_file)  # (T, C)
    eeg32 = ensure_32_channels(eeg)  # (T, 32)

    # 2) Extract features, shape (N, feat_dim).
    extractor = FeatureExtractor32(fs=SAMPLING_RATE, window_size=WINDOW_SIZE, stride=STRIDE)
    feats = extractor.extract(eeg32)

    # 3) Load the model and its checkpoint metadata.
    model, ckpt = load_model(model_path)
    if not isinstance(ckpt, dict):
        ckpt = {}

    # 4) Optional normalisation with statistics stored in the checkpoint.
    #    The stats are flattened first so that shapes such as (1, feat_dim)
    #    are applied instead of being silently skipped.
    mean = ckpt.get("global_mean")
    std = ckpt.get("global_std")
    if mean is not None and std is not None:
        mean = np.asarray(mean, dtype=np.float32).reshape(-1)
        std = np.asarray(std, dtype=np.float32).reshape(-1)
        n_features = feats.shape[1]
        if mean.size == n_features and std.size == n_features:
            feats = (feats - mean) / (std + 1e-8)
        else:
            # Still best effort (no crash), but running on unnormalised
            # features must be visible to the caller.
            warnings.warn(
                "Checkpoint normalisation stats do not match the feature width "
                f"(mean={mean.size}, std={std.size}, features={n_features}); "
                "running without normalisation.",
                RuntimeWarning,
                stacklevel=2,
            )

    # 5) Inference: p(MDD) for every slice, averaged to a subject-level score.
    x = torch.from_numpy(feats).to(DEVICE)
    with torch.no_grad():
        cls_out, _ = model(x)
        prob_mdd = torch.softmax(cls_out, dim=1)[:, 1].detach().cpu().numpy()
    p_mdd_mean = float(np.mean(prob_mdd))

    # A threshold stored in the checkpoint takes priority over the default.
    thr = float(ckpt.get("subject_threshold", DEFAULT_SUBJECT_THRESHOLD))
    pred_label = "MDD" if p_mdd_mean >= thr else "HC"
    return {
        "mat_file": mat_file,
        "pred_label": pred_label,
        "p_mdd_mean": p_mdd_mean,
        "threshold": thr,
        "n_slices": int(feats.shape[0]),
    }
# =========================================================
# 5) CLI命令行运行入口
# =========================================================
def main():
    """Command-line entry point: parse arguments, run inference, print a report."""
    parser = argparse.ArgumentParser(
        description="Infer HC/MDD from 64ch->32ch EEG mat using a .pth FusionNet model (no Asym)."
    )
    parser.add_argument(
        "--eeg_dir",
        type=str,
        required=True,
        help="包含.mat EEG文件的文件夹自动读取第一个.mat",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="训练好的.pth模型路径",
    )
    args = parser.parse_args()

    res = predict_hc_mdd(args.eeg_dir, args.model_path)

    report = (
        "\n========== 推理结果 ==========",
        f"MAT文件: {res['mat_file']}",
        f"切片数量: {res['n_slices']}",
        f"p(MDD)_mean: {res['p_mdd_mean']:.4f}",
        f"阈值thr: {res['threshold']:.4f}",
        f"预测结果: {res['pred_label']}",
        "==============================\n",
    )
    for line in report:
        print(line)


if __name__ == "__main__":
    main()