初始化zmq 项目
This commit is contained in:
851
Tools/plot_MI_EEG.py
Normal file
851
Tools/plot_MI_EEG.py
Normal file
@@ -0,0 +1,851 @@
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
import os
|
||||
import io
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.patches import Ellipse
|
||||
import matplotlib.cm as cm
|
||||
import matplotlib.colors as mcolors
|
||||
from scipy.spatial import Delaunay
|
||||
from scipy.interpolate import Rbf
|
||||
from scipy.signal import welch
|
||||
from scipy.stats import sem
|
||||
from scipy.signal import butter, filtfilt, hilbert
|
||||
import base64
|
||||
|
||||
# 位置坐标
|
||||
def read_ch_pos(file_path=r'xy_64.xlsx'):
|
||||
"""
|
||||
将电极位置信息转换为Dict
|
||||
|
||||
参数:
|
||||
file_path: 电极位置存储文件, 必须包含'channel', 'x', 'y', 'z'列
|
||||
|
||||
"""
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
file_path = os.path.join(script_dir,file_path )
|
||||
df = pd.read_excel(file_path)
|
||||
# 确保列名正确
|
||||
if not all(col in df.columns for col in ['channel', 'x', 'y', 'z']):
|
||||
raise ValueError("DataFrame必须包含'channel', 'x', 'y', 'z'列")
|
||||
# 创建电极位置字典
|
||||
ch_pos = {}
|
||||
for _, row in df.iterrows():
|
||||
ch_pos[row['channel']] = [row['x'], row['y'], row['z']]
|
||||
return ch_pos
|
||||
# 头部轮廓
|
||||
def draw_head(ax, center=(0, 0), radius=1.0, zorder=4):
|
||||
"""
|
||||
绘制头部轮廓、鼻子和耳朵。
|
||||
|
||||
参数:
|
||||
- ax : matplotlib Axes 对象
|
||||
- center : (x, y) 头中心坐标
|
||||
- radius : float, 头半径
|
||||
- zorder : 绘制层级
|
||||
"""
|
||||
|
||||
# 头圆
|
||||
head = plt.Circle(center, radius, fill=False, color='k', linewidth=1, zorder=zorder)
|
||||
ax.add_artist(head)
|
||||
|
||||
# 鼻子(参考 _make_head_outlines)
|
||||
dx = np.exp(np.arccos(np.deg2rad(12)) * 1j)
|
||||
dx_real, dx_imag = dx.real, dx.imag
|
||||
nose_x = np.array([-dx_real, 0, dx_real]) * radius + center[0]
|
||||
nose_y = np.array([dx_imag, 1.15, dx_imag]) * radius + center[1]
|
||||
ax.plot(nose_x, nose_y, color='k', linewidth=1, zorder=zorder)
|
||||
|
||||
# 耳朵(参考 _make_head_outlines 手动标定)
|
||||
ear_radius = radius * 0.12
|
||||
ear_scale = radius * 2 # 根据半径缩放
|
||||
theta = np.linspace(np.pi / 2, 3 * np.pi / 2, 30)
|
||||
|
||||
# 左耳
|
||||
left_ear_x_array = np.array([0.497, 0.510, 0.518, 0.5299, 0.5419,
|
||||
0.54, 0.547, 0.532, 0.510, 0.489]) * ear_scale
|
||||
left_ear_y_array = np.array([0.0555, 0.0775, 0.0783, 0.0746, 0.0555,
|
||||
-0.0055, -0.0932, -0.1313, -0.1384, -0.1199]) * ear_scale + center[1]
|
||||
ax.plot(center[0] - left_ear_x_array, left_ear_y_array, color='k', linewidth=1, zorder=zorder)
|
||||
|
||||
# 右耳
|
||||
right_ear_x_array = np.array([0.497, 0.510, 0.518, 0.5299, 0.5419,
|
||||
0.54, 0.547, 0.532, 0.510, 0.489]) * ear_scale
|
||||
right_ear_y_array = np.array([0.0555, 0.0775, 0.0783, 0.0746, 0.0555,
|
||||
-0.0055, -0.0932, -0.1313, -0.1384, -0.1199]) * ear_scale + center[1]
|
||||
ax.plot(center[0] + right_ear_x_array, right_ear_y_array, color='k', linewidth=1, zorder=zorder)
|
||||
# 地形图 插值
|
||||
def rbf_D_interpolate(xy, v, center=(0, 0), radius=1.1, grid_res=300,
|
||||
n_extra=32, rbf_func='multiquadric', smooth=0,
|
||||
border='mean', border_scale=1.0001, n_ngb=4):
|
||||
"""
|
||||
使用 RBF + Delaunay 邻域均值方式生成平滑的 EEG topomap 插值表面。
|
||||
|
||||
参数
|
||||
----
|
||||
xy : (N,2) array
|
||||
电极二维坐标(与绘图坐标系一致)
|
||||
v : (N,) array
|
||||
每个电极对应的值(e.g. PSD)
|
||||
center : tuple (x0, y0)
|
||||
头部圆心(默认 (0,0))
|
||||
radius : float
|
||||
头部半径(用于生成边界点与网格范围)
|
||||
grid_res : int
|
||||
网格分辨率(每轴点数)
|
||||
n_extra : int
|
||||
边界虚拟点数量
|
||||
rbf_func : str
|
||||
RBF 内核名称('multiquadric','thin_plate','gaussian',...)
|
||||
smooth : float
|
||||
RBF 平滑参数
|
||||
border : 'mean' or float
|
||||
若 'mean':边界点用邻近真实通道均值赋值(推荐)
|
||||
若 float:边界点赋相同常数值
|
||||
border_scale : float
|
||||
边界点半径相对 radius 的缩放(略微 >1 用以外推)
|
||||
n_ngb : int
|
||||
为每个边界点取值时使用的最近真实通道数
|
||||
|
||||
返回
|
||||
----
|
||||
zi : (grid_res, grid_res) ndarray
|
||||
插值结果(与 grid_x, grid_y 对齐)
|
||||
grid_x, grid_y : ndarrays
|
||||
meshgrid(由 np.meshgrid 生成)
|
||||
"""
|
||||
xy = np.asarray(xy)
|
||||
v = np.asarray(v)
|
||||
if xy.ndim != 2 or xy.shape[1] != 2:
|
||||
raise ValueError("xy must be shape (n_channels, 2)")
|
||||
|
||||
n_points = xy.shape[0]
|
||||
|
||||
# --- 1. 生成边界虚拟点(圆周) ---
|
||||
theta = np.linspace(0.0, 2 * np.pi, n_extra, endpoint=False)
|
||||
r_border = radius * border_scale
|
||||
border_xy = np.column_stack([center[0] + r_border * np.cos(theta),
|
||||
center[1] + r_border * np.sin(theta)])
|
||||
|
||||
# --- 2. 用 Delaunay 建图以便找到邻居(对边界点取邻居均值) ---
|
||||
# 合并用于三角化的位置(真实点 + 边界点)
|
||||
tri_xy = np.vstack([xy, border_xy])
|
||||
tri = Delaunay(tri_xy)
|
||||
|
||||
# --- 3. 为边界点赋值 ---
|
||||
if isinstance(border, str) and border == 'mean':
|
||||
# 使用 Delaunay 的 vertex_neighbor_vertices 索引
|
||||
# 注意:tri.vertex_neighbor_vertices 给出 vertices -> neighbor indptr
|
||||
indices, indptr = tri.vertex_neighbor_vertices
|
||||
v_extra = np.zeros(n_extra)
|
||||
used = np.zeros(n_extra, dtype=bool)
|
||||
# 边界点在 tri_xy 中的索引范围
|
||||
rng = range(n_points, n_points + n_extra)
|
||||
for idx, extra_idx in enumerate(rng):
|
||||
neigh = indptr[indices[extra_idx]:indices[extra_idx + 1]]
|
||||
# 仅保留原始点索引(小于 n_points)
|
||||
neigh = neigh[neigh < n_points]
|
||||
if neigh.size > 0:
|
||||
used[idx] = True
|
||||
# 使用最近 n_ngb 个邻居的均值(若邻居多则取最近的 n_ngb)
|
||||
if neigh.size > n_ngb:
|
||||
# 计算距离并选取最近 n_ngb
|
||||
d = np.linalg.norm(xy[neigh] - tri_xy[extra_idx], axis=1)
|
||||
order = np.argsort(d)[:n_ngb]
|
||||
sel = neigh[order]
|
||||
else:
|
||||
sel = neigh
|
||||
v_extra[idx] = v[sel].mean()
|
||||
if not used.all() and used.any():
|
||||
v_extra[~used] = np.mean(v_extra[used])
|
||||
elif not used.any():
|
||||
v_extra[:] = np.mean(v)
|
||||
else:
|
||||
# border 是数值
|
||||
v_extra = np.full(n_extra, float(border))
|
||||
|
||||
# --- 4. 合并所有已知点并构建 RBF ---
|
||||
all_xy = np.vstack([xy, border_xy])
|
||||
all_v = np.concatenate([v, v_extra])
|
||||
|
||||
rbf = Rbf(all_xy[:, 0], all_xy[:, 1], all_v,
|
||||
function=rbf_func, smooth=smooth)
|
||||
|
||||
# --- 5. 生成网格(使用 meshgrid,与主函数保持一致) ---
|
||||
xmin, xmax = center[0] - radius, center[0] + radius
|
||||
ymin, ymax = center[1] - radius, center[1] + radius
|
||||
xi = np.linspace(xmin, xmax, grid_res)
|
||||
yi = np.linspace(ymin, ymax, grid_res)
|
||||
grid_x, grid_y = np.meshgrid(xi, yi) # meshgrid 与 imshow 对齐
|
||||
|
||||
# --- 6. 评估 RBF,返回与 grid 对齐的 zi ---
|
||||
zi = rbf(grid_x, grid_y)
|
||||
|
||||
return zi, grid_x, grid_y
|
||||
# plv矩阵计算
|
||||
def calculate_plv(data):
|
||||
"""
|
||||
计算相位锁定值(PLV)矩阵。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : ndarray, shape (num_channels, num_samples)
|
||||
EEG 数据,通道数为 num_channels,样本数为 num_samples。
|
||||
|
||||
Returns
|
||||
-------
|
||||
plv_matrix : ndarray, shape (num_channels, num_channels)
|
||||
计算得到的 PLV 矩阵,表示各通道间的相位同步。
|
||||
"""
|
||||
num_channels, num_samples = data.shape
|
||||
plv_matrix = np.zeros((num_channels, num_channels))
|
||||
|
||||
# 计算每个通道的解析信号
|
||||
analytic_signals = np.apply_along_axis(hilbert, axis=1, arr=data)
|
||||
|
||||
for i in range(num_channels):
|
||||
for j in range(i + 1, num_channels): # 只计算上三角矩阵,避免重复计算
|
||||
# 计算 phase difference
|
||||
phase_diff = np.angle(analytic_signals[i] * np.conj(analytic_signals[j]))
|
||||
plv = np.abs(np.mean(np.exp(1j * phase_diff)))
|
||||
plv_matrix[i, j] = plv
|
||||
plv_matrix[j, i] = plv # 对称矩阵
|
||||
|
||||
return plv_matrix
|
||||
# 矩阵阈值化
|
||||
def threshold_proportional(adj, prop=0.2):
|
||||
"""
|
||||
Apply a proportional threshold to retain the top proportion of strongest edges.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
adj : ndarray, shape (n_channels, n_channels)
|
||||
Adjacency matrix to threshold.
|
||||
prop : float
|
||||
Proportion of edges to retain (0 < prop <= 1).
|
||||
|
||||
Returns
|
||||
-------
|
||||
bin_adj : ndarray, shape (n_channels, n_channels)
|
||||
Binary adjacency matrix after thresholding.
|
||||
"""
|
||||
n = adj.shape[0]
|
||||
triu_idx = np.triu_indices(n, k=1)
|
||||
weights = adj[triu_idx]
|
||||
k = int(np.floor(len(weights) * prop))
|
||||
|
||||
# Ensure that at least one edge is retained
|
||||
k = max(k, 1)
|
||||
|
||||
# Get the threshold value
|
||||
thr = np.sort(weights)[-k]
|
||||
|
||||
# Apply the threshold to create a binary adjacency matrix
|
||||
bin_adj = np.where(adj >= thr, adj, 0.0)
|
||||
|
||||
return bin_adj
|
||||
# 单个脑网络
|
||||
def plot_single_network(ch_names,adj,ax=None,
|
||||
node_size=20, node_color='orange',highlight_nodes=[], show_names=True,
|
||||
edge_color='gray', weighted=True,
|
||||
radius=1.1, figsize=(6, 6),cmap='RdYlBu_r'):
|
||||
# 若 ax 未传入,则自己创建
|
||||
own_fig = False
|
||||
if ax is None:
|
||||
fig, ax = plt.subplots(figsize=figsize)
|
||||
own_fig = True
|
||||
else:
|
||||
fig = ax.figure
|
||||
|
||||
# 坐标归一化
|
||||
pos3d = read_ch_pos()
|
||||
all_chs_xy = np.array([pos3d[ch][:2] for ch in pos3d.keys()])
|
||||
all_chs_xy -= all_chs_xy.mean(axis=0)
|
||||
all_chs_xy /= np.sqrt((all_chs_xy ** 2).sum(axis=1)).max()
|
||||
xy_dict = dict(zip(pos3d.keys(), all_chs_xy))
|
||||
xy = np.array([xy_dict[ch] for ch in ch_names])
|
||||
center = xy_dict.get('CZ', np.mean(list(xy_dict.values()), axis=0))
|
||||
|
||||
# ===== 初始化绘图窗口 =====
|
||||
ax.set_aspect('equal')
|
||||
ax.axis('off')
|
||||
# 设置边界(与原类保持一致)
|
||||
ear_radius = radius * 0.12
|
||||
nose_height = radius * 0.15
|
||||
margin_x = radius * 0.12 + 0.05
|
||||
ax.set_xlim(center[0] - radius - margin_x, center[0] + radius + margin_x)
|
||||
ax.set_ylim(center[1] - radius - ear_radius, center[1] + radius + nose_height + ear_radius)
|
||||
|
||||
# 绘制头部轮廓
|
||||
draw_head(ax, center=center, radius=radius)
|
||||
|
||||
# 节点
|
||||
for ch in ch_names:
|
||||
color = 'red' if ch in highlight_nodes else node_color
|
||||
ax.scatter(*xy_dict[ch], s=node_size, color=color, edgecolor='k', zorder=4)
|
||||
if show_names:
|
||||
ax.text(xy_dict[ch][0], xy_dict[ch][1] + 0.03, ch,
|
||||
ha='center', va='bottom', fontsize=8, zorder=5)
|
||||
|
||||
# colorbar
|
||||
norm = mcolors.Normalize(vmin=0, vmax=1)
|
||||
color_map = matplotlib.colormaps.get_cmap(cmap)
|
||||
# ========= 边 ==========
|
||||
N = len(ch_names)
|
||||
for i in range(N):
|
||||
for j in range(i + 1, N):
|
||||
w = adj[i, j]
|
||||
if w > 0:
|
||||
x = [xy[i, 0], xy[j, 0]]
|
||||
y = [xy[i, 1], xy[j, 1]]
|
||||
lw = 1.5
|
||||
if weighted:
|
||||
ax.plot(x, y,
|
||||
color=color_map(norm(w)),
|
||||
linewidth=lw,
|
||||
alpha=0.7,
|
||||
zorder=3)
|
||||
else:
|
||||
ax.plot(x, y,
|
||||
color=edge_color,
|
||||
linewidth=lw,
|
||||
alpha=0.7,
|
||||
zorder=3)
|
||||
|
||||
if own_fig:
|
||||
# 不回传 添加颜色条
|
||||
sm = cm.ScalarMappable(norm=norm, cmap=color_map)
|
||||
cbar = plt.colorbar(sm, ax=ax, fraction=0.035)
|
||||
cbar.set_label('Connection Strength', fontsize=10)
|
||||
cbar.ax.tick_params(direction='in', labelsize=10)
|
||||
plt.show()
|
||||
return fig
|
||||
else:
|
||||
|
||||
return ax
|
||||
# 脑网络对比
|
||||
def plot_multiband_network(ch_names, adj_MI, adj_Rest,cmap='RdYlBu_r'):
|
||||
|
||||
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
|
||||
fontsize = 16
|
||||
fig.text(0.285, 0.08, 'MI', fontsize=fontsize, ha='center', va='center', rotation=0)
|
||||
fig.text(0.68, 0.08, 'Rest', fontsize=fontsize, ha='center', va='center', rotation=0)
|
||||
|
||||
im1 = plot_single_network(ch_names,adj_MI,ax=axes[0], show_names=True,cmap=cmap)
|
||||
# Rest 行
|
||||
im2 = plot_single_network(ch_names,adj_Rest,ax=axes[1],show_names=True,cmap=cmap)
|
||||
|
||||
# --- 合并 colorbar(右侧一个) ---
|
||||
norm = mcolors.Normalize(vmin=0, vmax=1)
|
||||
color_map = matplotlib.colormaps.get_cmap(cmap)
|
||||
sm = cm.ScalarMappable(norm=norm, cmap=color_map)
|
||||
cbar = plt.colorbar(sm, ax=axes.ravel().tolist(), fraction=0.02)
|
||||
cbar.set_label('Connection Strength', fontsize=10)
|
||||
cbar.ax.tick_params(direction='in', labelsize=10)
|
||||
|
||||
# 将图像保存到内存字节流(PNG 格式)
|
||||
buf = io.BytesIO()
|
||||
fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
|
||||
plt.close(fig) # 释放内存
|
||||
buf.seek(0)
|
||||
image_bytes = buf.read()
|
||||
buf.close()
|
||||
|
||||
return image_bytes
|
||||
|
||||
# 多个频带psd
|
||||
def compute_band_psd(eeg, fs, bands, labels, trial_idx=0,MI_label=1, Rest_label=2,avg = True):
|
||||
"""
|
||||
eeg: (n_trials, n_channels, n_samples)
|
||||
"""
|
||||
n_trials, n_channels, n_samples = eeg.shape
|
||||
band_names = list(bands.keys())
|
||||
n_bands = len(band_names)
|
||||
|
||||
psd_MI = np.zeros((n_bands, n_channels))
|
||||
psd_Rest = np.zeros((n_bands, n_channels))
|
||||
|
||||
# 先计算所有 trial 的功率谱
|
||||
f, Pxx = welch(eeg, fs=fs, axis=-1, nperseg=fs,noverlap = fs // 2)
|
||||
|
||||
|
||||
for bi, (bname, (f1, f2)) in enumerate(bands.items()):
|
||||
idx = np.logical_and(f >= f1, f <= f2)
|
||||
band_power = Pxx[:, :, idx].mean(axis=-1)
|
||||
|
||||
band_power_flat = band_power.flatten()
|
||||
power_min = band_power_flat.min()
|
||||
power_max = band_power_flat.max()
|
||||
if power_max - power_min > 1e-12:
|
||||
band_power_norm = (band_power - power_min) / (power_max - power_min)
|
||||
else:
|
||||
band_power_norm = band_power
|
||||
|
||||
if avg:
|
||||
psd_MI[bi] = band_power_norm[labels == MI_label].mean(axis=0)
|
||||
psd_Rest[bi] = band_power_norm[labels == Rest_label].mean(axis=0)
|
||||
else:
|
||||
psd_MI[bi] = band_power_norm[labels == MI_label][trial_idx]
|
||||
psd_Rest[bi] = band_power_norm[labels == Rest_label][trial_idx]
|
||||
return band_names, psd_MI, psd_Rest
|
||||
# 单个脑地形图
|
||||
def plot_single_topomap(ch_names, psd_values, cmap='RdYlBu_r', vlim=(0, 1),
|
||||
show_names=True, node_size=3, radius=1.1, grid_res=300,
|
||||
n_contours=None, contour_color='k',
|
||||
ax=None,figsize=(6,6)):
|
||||
# 若 ax 未传入,则自己创建
|
||||
own_fig = False
|
||||
if ax is None:
|
||||
fig, ax = plt.subplots(figsize=figsize)
|
||||
own_fig = True
|
||||
else:
|
||||
fig = ax.figure
|
||||
|
||||
# ===== 初始化绘图窗口 =====
|
||||
ax.set_aspect('equal')
|
||||
ax.axis('off')
|
||||
# ax.set_title("EEG topomap (MNE-like)")
|
||||
|
||||
# 坐标归一化
|
||||
pos3d = read_ch_pos()
|
||||
all_chs_xy = np.array([pos3d[ch][:2] for ch in pos3d.keys()])
|
||||
all_chs_xy -= all_chs_xy.mean(axis=0)
|
||||
all_chs_xy /= np.sqrt((all_chs_xy ** 2).sum(axis=1)).max()
|
||||
pos2d_dict = dict(zip(pos3d.keys(), all_chs_xy))
|
||||
xy = np.array([pos2d_dict[ch] for ch in ch_names])
|
||||
center = pos2d_dict.get('CZ', np.mean(list(pos2d_dict.values()), axis=0))
|
||||
|
||||
# 绘制头部轮廓
|
||||
draw_head(ax, center=center, radius=radius)
|
||||
# 绘制电极
|
||||
fontsize = 4
|
||||
ax.scatter(xy[:, 0], xy[:, 1], c='k', s=node_size, zorder=5)
|
||||
if show_names:
|
||||
for i, ch in enumerate(ch_names):
|
||||
ax.text(xy[i, 0], xy[i, 1] + 0.03, ch,
|
||||
ha='center', va='bottom', fontsize=fontsize, zorder=6)
|
||||
|
||||
# 数据插值
|
||||
zi, grid_x, grid_y = rbf_D_interpolate(
|
||||
xy, psd_values, radius=radius,
|
||||
grid_res=grid_res
|
||||
)
|
||||
xmin, xmax = center[0] - radius, center[0] + radius
|
||||
ymin, ymax = center[1] - radius, center[1] + radius
|
||||
extent = (xmin, xmax, ymin, ymax)
|
||||
im = ax.imshow(zi, extent=extent, origin='lower',
|
||||
cmap=cmap, vmin=vlim[0], vmax=vlim[1],
|
||||
interpolation='bicubic', zorder=0)
|
||||
# 裁剪路径
|
||||
patch_ = Ellipse(center, 2 * radius, 2 * radius, clip_on=True, transform=ax.transData)
|
||||
im.set_clip_path(patch_)
|
||||
# 初始等高线
|
||||
linewidths = 0.5
|
||||
if n_contours is None:
|
||||
cset = ax.contour(grid_x, grid_y, zi,
|
||||
colors=contour_color, linewidths=linewidths, zorder=2)
|
||||
else:
|
||||
cset = ax.contour(grid_x, grid_y, zi, levels=n_contours,
|
||||
colors=contour_color, linewidths=linewidths, zorder=2)
|
||||
cset.set_clip_path(patch_)
|
||||
|
||||
|
||||
|
||||
if own_fig:
|
||||
# 不回传 添加颜色条
|
||||
plt.colorbar(im, ax=ax, fraction=0.035)
|
||||
plt.show()
|
||||
return fig
|
||||
else:
|
||||
# plt.colorbar(im, ax=ax, fraction=0.035)
|
||||
return im
|
||||
# 脑地形图对比
|
||||
def plot_multiband_topomaps(ch_names, psd_MI, psd_Rest, bands):
|
||||
band_names = list(bands.keys()) # 改动 1:新增这行
|
||||
n_bands = len(band_names)
|
||||
fig, axes = plt.subplots(2, n_bands, figsize=(3*n_bands, 6))
|
||||
|
||||
fontsize = 16
|
||||
|
||||
axes[0, 0].text(-0.1, 0.5, 'MI', transform=axes[0, 0].transAxes, rotation=0, va='center', ha='center', fontsize=fontsize-2)
|
||||
axes[1, 0].text(-0.1, 0.5, 'Rest', transform=axes[1, 0].transAxes, rotation=0, va='center', ha='center', fontsize=fontsize-2)
|
||||
|
||||
imgs = []
|
||||
for i, bname in enumerate(band_names):
|
||||
axes[0, i].set_title(bname, fontsize=fontsize, pad=0)
|
||||
# MI 行
|
||||
im1 = plot_single_topomap(ch_names,psd_MI[i],ax=axes[0, i], show_names=True)
|
||||
# Rest 行
|
||||
im2 = plot_single_topomap(ch_names,psd_Rest[i],ax=axes[1, i],show_names=True)
|
||||
imgs.append(im1)
|
||||
|
||||
# --- 单个右侧合并 colorbar ---
|
||||
cbar = fig.colorbar(imgs[0], ax=axes,fraction=0.02)
|
||||
# cbar.set_label("PSD Power",fontsize=fontsize-4)
|
||||
cbar.ax.tick_params(direction='in', labelsize=10)
|
||||
|
||||
# 将图像保存到内存字节流(PNG 格式)
|
||||
buf = io.BytesIO()
|
||||
fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
|
||||
plt.close(fig) # 释放内存
|
||||
buf.seek(0)
|
||||
image_bytes = buf.read()
|
||||
buf.close()
|
||||
|
||||
return image_bytes
|
||||
|
||||
# 小波
|
||||
def morlet_wavelet(f, fs, n_cycles=7):
|
||||
"""
|
||||
创建 Morlet 小波
|
||||
f: 频率
|
||||
fs: 采样率
|
||||
"""
|
||||
sigma_t = n_cycles / (2 * np.pi * f)
|
||||
t = np.arange(-3*sigma_t, 3*sigma_t, 1/fs)
|
||||
wavelet = (np.pi**-0.25) * np.exp(2j*np.pi*f*t) * np.exp(-(t**2)/(2*sigma_t**2))
|
||||
return wavelet
|
||||
|
||||
|
||||
# 希尔伯特变换 计算ERDS 效果不佳
|
||||
def bandpass_filter(data, fs, band, order=4):
|
||||
nyq = fs / 2
|
||||
b, a = butter(order, [band[0]/nyq, band[1]/nyq], btype='band')
|
||||
return filtfilt(b, a, data, axis=-1)
|
||||
def compute_power_hilbert(filtered_data,is_dB =True):
|
||||
analytic = hilbert(filtered_data, axis=-1)
|
||||
power = np.abs(analytic) ** 2
|
||||
if is_dB:
|
||||
power = 10 * np.log10(power)
|
||||
return power
|
||||
def compute_power(data, fs=250,
|
||||
bands={"mu": (8,12), "beta": (13,30)}):
|
||||
"""
|
||||
返回:
|
||||
power_dict[band] = (n_trials, n_ch, n_samples)
|
||||
"""
|
||||
power_dict = {}
|
||||
for band_name, band_range in bands.items():
|
||||
filt = bandpass_filter(data, fs, band_range)
|
||||
power = compute_power_hilbert(filt)
|
||||
power_dict[band_name] = power
|
||||
|
||||
return power_dict
|
||||
|
||||
def compute_erds(power_MI, power_Rest, baseline_period=None):
|
||||
"""
|
||||
计算事件相关去同步/同步 (ERDS)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
power_MI, power_Rest: (n_trials, n_ch, n_samples)
|
||||
功率数据,单位为 µV² 或 dB(取决于 compute_power_hilbert 的 is_dB 参数)
|
||||
baseline_period: tuple (start_idx, end_idx) or None
|
||||
基线时间段索引。如果为None,使用 Rest 状态的平均值作为基线
|
||||
|
||||
返回:
|
||||
MI_erds_mean, MI_erds_sem
|
||||
Rest_erds_mean, Rest_erds_sem
|
||||
所有返回值的形状为 (n_ch, n_samples)
|
||||
"""
|
||||
|
||||
if baseline_period is not None:
|
||||
start_idx, end_idx = baseline_period
|
||||
baseline = np.concatenate([power_MI[:, :, start_idx:end_idx],
|
||||
power_Rest[:, :, start_idx:end_idx]], axis=0)
|
||||
baseline = baseline.mean(axis=(0, 2), keepdims=True)
|
||||
else:
|
||||
baseline = power_Rest.mean(axis=(0,2), keepdims=True)
|
||||
|
||||
# === ERDS (%) ===
|
||||
MI_erds = (power_MI - baseline) / baseline * 100
|
||||
Rest_erds = (power_Rest - baseline) / baseline * 100
|
||||
|
||||
return (
|
||||
MI_erds.mean(axis=0), sem(MI_erds, axis=0),
|
||||
Rest_erds.mean(axis=0), sem(Rest_erds, axis=0),
|
||||
)
|
||||
|
||||
def compute_all_erds(MI_power_dict, Rest_power_dict):
|
||||
"""
|
||||
对多个频带同时计算 ERDS。
|
||||
|
||||
输入:
|
||||
MI_power_dict[band] = (n_trials, n_ch, n_samples)
|
||||
Rest_power_dict[band] = (n_trials, n_ch, n_samples)
|
||||
|
||||
输出:
|
||||
erds_MI[band] = (mean, sem)
|
||||
erds_Rest[band] = (mean, sem)
|
||||
"""
|
||||
|
||||
erds_MI = {}
|
||||
erds_Rest = {}
|
||||
|
||||
for band in MI_power_dict.keys():
|
||||
MI_power = MI_power_dict[band]
|
||||
Rest_power = Rest_power_dict[band]
|
||||
|
||||
MI_mean, MI_sem, Rest_mean, Rest_sem = compute_erds(MI_power, Rest_power)
|
||||
|
||||
erds_MI[band] = (MI_mean, MI_sem)
|
||||
erds_Rest[band] = (Rest_mean, Rest_sem)
|
||||
|
||||
return erds_MI, erds_Rest
|
||||
|
||||
def plot_compare_erds(data_MI, data_Rest, mode="power",
|
||||
ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1', 'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4'],
|
||||
compare_names=['C3', 'CZ', 'C4'], bands=['mu', 'beta'],
|
||||
fs=250, t=None, figsize=(12,6)):
|
||||
|
||||
n_bands = len(bands)
|
||||
n_chs = len(compare_names)
|
||||
|
||||
# 自动添加单位
|
||||
if mode == "power":
|
||||
# y_unit = "Power (µV²)"
|
||||
y_unit = "Power (dB)"
|
||||
elif mode == "erds":
|
||||
y_unit = "ERDS (%)"
|
||||
else:
|
||||
y_unit = ""
|
||||
|
||||
if t is None:
|
||||
n_samples = next(iter(data_MI.values())).shape[-1] \
|
||||
if mode=="power" else next(iter(data_MI.values()))[0].shape[-1]
|
||||
t = np.arange(n_samples) / fs
|
||||
|
||||
fig, axes = plt.subplots(n_bands, n_chs, figsize=figsize, sharex=True, sharey=True)
|
||||
|
||||
for i, band in enumerate(bands):
|
||||
|
||||
# 选择数据结构
|
||||
if mode == "power":
|
||||
MI_band = data_MI[band] # (trials, ch, samples)
|
||||
Rest_band = data_Rest[band]
|
||||
|
||||
avg_MI = MI_band.mean(axis=0)
|
||||
sem_MI = MI_band.std(axis=0)/np.sqrt(MI_band.shape[0])
|
||||
|
||||
avg_Rest = Rest_band.mean(axis=0)
|
||||
sem_Rest = Rest_band.std(axis=0)/np.sqrt(Rest_band.shape[0])
|
||||
|
||||
elif mode == "erds":
|
||||
avg_MI, sem_MI = data_MI[band]
|
||||
avg_Rest, sem_Rest = data_Rest[band]
|
||||
|
||||
for j, ch in enumerate(compare_names):
|
||||
ax = axes[i, j] if n_bands > 1 else axes[j]
|
||||
|
||||
ch_idx = ch_names.index(ch)
|
||||
|
||||
# 绘制 MI
|
||||
ax.plot(t, avg_MI[ch_idx], color="C0", label="MI")
|
||||
ax.fill_between(t,
|
||||
avg_MI[ch_idx]-sem_MI[ch_idx],
|
||||
avg_MI[ch_idx]+sem_MI[ch_idx],
|
||||
alpha=0.3, color="C0")
|
||||
|
||||
# 绘制 Rest
|
||||
ax.plot(t, avg_Rest[ch_idx], color="C1", label="Rest")
|
||||
ax.fill_between(t,
|
||||
avg_Rest[ch_idx]-sem_Rest[ch_idx],
|
||||
avg_Rest[ch_idx]+sem_Rest[ch_idx],
|
||||
alpha=0.3, color="C1")
|
||||
|
||||
if i == 0:
|
||||
ax.set_title(ch)
|
||||
|
||||
# ← Y 轴加单位
|
||||
if j == 0:
|
||||
ax.set_ylabel(f"{band}\n{y_unit}")
|
||||
|
||||
if i == n_bands - 1:
|
||||
ax.set_xlabel("Time (s)")
|
||||
|
||||
ax.grid(alpha=0.3)
|
||||
|
||||
if i == 0 and j == n_chs - 1:
|
||||
ax.legend()
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
# 将图像保存到内存字节流(PNG 格式)
|
||||
buf = io.BytesIO()
|
||||
fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
|
||||
plt.close(fig) # 释放内存
|
||||
buf.seek(0)
|
||||
image_bytes = buf.read()
|
||||
buf.close()
|
||||
|
||||
return image_bytes
|
||||
|
||||
# 对比 MI vs Rest 的功率谱密度 PSD
|
||||
def plot_psd_compare(MI_data, Rest_data, ch_names, compare_names=['C3', 'CZ', 'C4'],
|
||||
fs=250, nperseg=None, average=True, show_sem=True,
|
||||
figsize=(12, 3), save_dir=None, filename="psd.png"):
|
||||
"""
|
||||
对比 MI vs Rest 的功率谱密度 PSD
|
||||
|
||||
MI_data, Rest_data: (n_trials, n_ch, n_samples)
|
||||
channels: 需要绘制的通道
|
||||
average: 是否对所有试次平均
|
||||
show_sem: 是否绘制 SEM 阴影
|
||||
"""
|
||||
|
||||
n_trials, n_ch, n_samples = MI_data.shape
|
||||
n_trials = min(len(MI_data), len(Rest_data))
|
||||
# assert Rest_data.shape == MI_data.shape, "MI 和 Rest 数据维度必须一致"
|
||||
|
||||
if nperseg is None:
|
||||
nperseg = fs # 每 1 秒窗长度
|
||||
|
||||
# 计算 MI PSD
|
||||
psd_MI_all = []
|
||||
for trial in range(n_trials):
|
||||
psd_trial = []
|
||||
for ch in range(n_ch):
|
||||
f, Pxx = welch(MI_data[trial, ch], fs=fs, nperseg=nperseg)
|
||||
psd_trial.append(Pxx)
|
||||
psd_MI_all.append(psd_trial)
|
||||
psd_MI_all = np.array(psd_MI_all)
|
||||
|
||||
# 计算 Rest PSD
|
||||
psd_Rest_all = []
|
||||
for trial in range(n_trials):
|
||||
psd_trial = []
|
||||
for ch in range(n_ch):
|
||||
_, Pxx = welch(Rest_data[trial, ch], fs=fs, nperseg=nperseg)
|
||||
psd_trial.append(Pxx)
|
||||
psd_Rest_all.append(psd_trial)
|
||||
psd_Rest_all = np.array(psd_Rest_all)
|
||||
|
||||
# ---- Plot ----
|
||||
fig, ax = plt.subplots(1, len(compare_names), figsize=figsize)
|
||||
if len(compare_names) == 1:
|
||||
ax = [ax]
|
||||
|
||||
for i, ch in enumerate(compare_names):
|
||||
ch_idx = ch_names.index(ch)
|
||||
psd_MI_ch = psd_MI_all[:, ch_idx, :]
|
||||
psd_Rest_ch = psd_Rest_all[:, ch_idx, :]
|
||||
|
||||
if average:
|
||||
mean_MI = psd_MI_ch.mean(axis=0)
|
||||
mean_Rest = psd_Rest_ch.mean(axis=0)
|
||||
|
||||
ax[i].plot(f, mean_MI, color='C0', label='MI')
|
||||
ax[i].plot(f, mean_Rest, color='C1', label='Rest')
|
||||
|
||||
if show_sem:
|
||||
ax[i].fill_between(f, mean_MI - sem(psd_MI_ch, axis=0),
|
||||
mean_MI + sem(psd_MI_ch, axis=0), color='C0', alpha=0.3)
|
||||
ax[i].fill_between(f, mean_Rest - sem(psd_Rest_ch, axis=0),
|
||||
mean_Rest + sem(psd_Rest_ch, axis=0), color='C1', alpha=0.3)
|
||||
else:
|
||||
ax[i].plot(f, psd_MI_ch.T, color='C0', alpha=0.3)
|
||||
ax[i].plot(f, psd_Rest_ch.T, color='C1', alpha=0.3)
|
||||
|
||||
ax[i].set_title(ch)
|
||||
ax[i].set_xlabel("Frequency (Hz)")
|
||||
ax[i].set_ylabel("PSD (μV²/Hz)")
|
||||
ax[i].grid(alpha=0.3)
|
||||
if i == 0:
|
||||
ax[i].legend()
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
# 将图像保存到内存字节流(PNG 格式)
|
||||
buf = io.BytesIO()
|
||||
fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
|
||||
plt.close(fig) # 释放内存
|
||||
buf.seek(0)
|
||||
image_bytes = buf.read()
|
||||
buf.close()
|
||||
|
||||
return image_bytes
|
||||
|
||||
|
||||
def plotMain(
|
||||
ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1', 'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4'],
|
||||
compare_names = [ 'C3','CZ','C4'],
|
||||
Data = None,labels = None,MI_label = None,Rest_label = None,
|
||||
fs = 250):
|
||||
|
||||
trial_idx = 0
|
||||
|
||||
# 数据划分
|
||||
if not MI_label:
|
||||
label_ = np.unique(labels)
|
||||
else:
|
||||
label_ = (MI_label,Rest_label)
|
||||
MI_data = Data[labels == label_[0]]
|
||||
Rest_data = Data[labels == label_[1]]
|
||||
|
||||
# 典型 EEG 频带
|
||||
FREQ_BANDS = {
|
||||
"Delta (0.8-4Hz)": (0.8, 4),
|
||||
"Theta (4-8Hz)": (4, 8),
|
||||
"Alpha (8-12Hz)": (8, 12),
|
||||
"Beta (12-30Hz)": (12, 30),
|
||||
"All (0.8-30Hz)": (0.8, 30)
|
||||
}
|
||||
# 利用welch估算PSD
|
||||
band_names, psd_MI, psd_Rest= compute_band_psd(
|
||||
eeg=Data,
|
||||
fs=fs,
|
||||
bands=FREQ_BANDS,
|
||||
labels=labels,
|
||||
trial_idx=trial_idx,
|
||||
MI_label=MI_label,
|
||||
Rest_label=Rest_label,
|
||||
avg= True
|
||||
)
|
||||
# 绘制地形图
|
||||
topomaps_imgBytes = plot_multiband_topomaps(
|
||||
ch_names=ch_names,
|
||||
psd_MI=psd_MI,
|
||||
psd_Rest=psd_Rest,
|
||||
bands=FREQ_BANDS
|
||||
)
|
||||
|
||||
# 绘制脑网络
|
||||
mi_plv_matrix = calculate_plv(MI_data[trial_idx])
|
||||
mi_BI_matrix = threshold_proportional(mi_plv_matrix, prop=0.3)
|
||||
rest_plv_matrix = calculate_plv(Rest_data[trial_idx])
|
||||
rest_BI_matrix = threshold_proportional(rest_plv_matrix, prop=0.3)
|
||||
network_imgBytes = plot_multiband_network(ch_names, mi_BI_matrix, rest_BI_matrix)
|
||||
|
||||
# ERDS 先计算erds,后平均
|
||||
MI_power = compute_power(MI_data)
|
||||
Rest_power = compute_power(Rest_data)
|
||||
erds_dict_MI, erds_dict_Rest = compute_all_erds(MI_power, Rest_power)
|
||||
erds_imgBytes = plot_compare_erds(erds_dict_MI, erds_dict_Rest, ch_names=ch_names,
|
||||
compare_names=compare_names, bands=['mu', 'beta'],
|
||||
fs=fs, mode="erds")
|
||||
|
||||
# 绘制PSD
|
||||
psd_imgBytes = plot_psd_compare(MI_data, Rest_data, ch_names = ch_names, compare_names=compare_names,
|
||||
fs=fs, nperseg=None, average=True, show_sem=True,
|
||||
figsize=(12, 3))
|
||||
return {'topomaps_imgBytes':base64.b64encode(topomaps_imgBytes).decode(),'network_imgBytes':base64.b64encode(network_imgBytes).decode(),
|
||||
'erds_imgBytes':base64.b64encode(erds_imgBytes).decode(),'psd_imgBytes':base64.b64encode(psd_imgBytes).decode()}
|
||||
|
||||
if __name__ == '__main__':
|
||||
allData = np.random.uniform(-50,50,size=(80,21,1000))
|
||||
allLabel = np.random.randint(1,3,size=(80,))
|
||||
allData = allData[:len(allLabel)]
|
||||
ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1',
|
||||
'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4']
|
||||
compare_names = ['C3', 'CZ', 'C4']
|
||||
ret = plotMain(ch_names=ch_names, compare_names=compare_names, Data=allData, labels=allLabel, MI_label=1, Rest_label=2,
|
||||
fs=250)
|
||||
print('计算完成,开始发送')
|
||||
from Zmq.zmqClient import zmqClient
|
||||
|
||||
zmqClient = zmqClient('192.168.76.101', 8088)
|
||||
zmqClient.connect()
|
||||
zmqClient.send_to_all('miReport', ret)
|
||||
Reference in New Issue
Block a user