Files
bci_algo/SSMVEP/algorithm/base.py
2026-06-05 09:34:29 +08:00

419 lines
13 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
#
# Authors: Swolf <swolfforever@gmail.com>
# Date: 2021/1/07
# License: MIT License
from typing import Optional, List, Tuple, Union
import warnings
import numpy as np
from numpy import ndarray
from numpy.linalg import linalg
from scipy.linalg import solve, qr
from scipy.signal import sosfiltfilt, cheby1, cheb1ord
from sklearn.base import BaseEstimator, TransformerMixin, clone
def robust_pattern(W: ndarray, Cx: ndarray, Cs: ndarray) -> ndarray:
    """Transform spatial filters to spatial patterns based on paper [1]_.

    A spatial filter only describes how channels are combined to extract a
    signal of interest; it is not directly interpretable. For
    neurophysiological interpretation or visualisation of the weights, the
    corresponding activation patterns are needed. Following [1]_ they are
    computed as::

        A = Cx @ W @ inv(Cs)

    update log:
        2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

    Parameters
    ----------
    W : ndarray
        Spatial filters, shape (n_channels, n_filters).
    Cx : ndarray
        Covariance matrix of eeg data, shape (n_channels, n_channels).
    Cs : ndarray
        Covariance matrix of source data (typically the filtered signals
        ``W.T @ X``), shape (n_filters, n_filters).

    Returns
    -------
    A : ndarray
        Spatial patterns, shape (n_channels, n_filters); each column is a
        spatial pattern.

    References
    ----------
    .. [1] Haufe, Stefan, et al. "On the interpretation of weight vectors of
       linear models in multivariate neuroimaging." Neuroimage 87 (2014):
       96-110.
    """
    # A = Cx W Cs^{-1}  <=>  A.T = solve(Cs.T, (Cx W).T).
    # Solving the linear system is more stable than forming inv(Cs).
    # see https://github.com/robintibor/fbcsp/blob/master/fbcsp/signalproc.py
    # and https://ww2.mathworks.cn/help/matlab/ref/mldivide.html
    A = solve(Cs.T, np.dot(Cx, W).T).T
    return A
class FilterBank(BaseEstimator, TransformerMixin):
    """
    Filter bank decomposition: the input signal is split into several sub-band
    components with a bank of band-pass filters, and one clone of
    ``base_estimator`` is fitted on / extracts features from each sub-band.

    update log:
        2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

    Parameters
    ----------
    base_estimator : BaseEstimator
        Estimator template used for model training and feature extraction.
        One clone (``sklearn.base.clone``) is created per sub-band.
    filterbank : list[ndarray]
        Band-pass filters in second-order-sections (sos) form, one per
        sub-band; each is applied with ``scipy.signal.sosfiltfilt``.
    n_jobs : int, optional
        Number of CPU cores. Stored for API compatibility only: ``fit`` and
        ``transform`` currently process the sub-bands sequentially and do not
        read this value. The default is None.

    Attributes
    ----------
    estimators_ : list[BaseEstimator]
        Per-sub-band fitted estimators, in filter-bank order (set by ``fit``).

    References
    ----------
    .. [1] Chen X, Wang Y, Nakanishi M, et al. High-speed spelling with a
       noninvasive brain-computer interface[J]. Proceedings of the national
       academy of sciences, 2015, 112(44): E6058-E6067.
    """

    def __init__(
        self,
        base_estimator: BaseEstimator,
        filterbank: List[ndarray],
        n_jobs: Optional[int] = None,
    ):
        self.base_estimator = base_estimator
        self.filterbank = filterbank
        self.n_jobs = n_jobs

    def fit(self, X: ndarray, y: Optional[ndarray] = None, **kwargs):
        """
        Fit one clone of ``base_estimator`` on each filtered sub-band.

        update log:
            2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

        Parameters
        ----------
        X : ndarray, shape(n_trials, n_channels, n_samples)
            Training signal; filtered along the last axis.
        y : ndarray, optional
            Labels, forwarded unchanged to every sub-band estimator's ``fit``.
        **kwargs
            Extra keyword arguments forwarded unchanged to every sub-band
            estimator's ``fit`` (e.g. reference signals ``Yf``, depending on
            the estimator).

        Returns
        -------
        self : FilterBank
            The fitted filter bank.
        """
        # Fresh clones so that refitting never reuses state from a prior fit.
        self.estimators_ = [
            clone(self.base_estimator) for _ in range(len(self.filterbank))
        ]
        X = self.transform_filterbank(X)
        for i, est in enumerate(self.estimators_):
            est.fit(X[i], y, **kwargs)
        return self

    def transform(self, X: ndarray, **kwargs):
        """
        Filter X through the filter bank and let each fitted sub-band
        estimator extract features from its sub-band.

        update log:
            2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

        Parameters
        ----------
        X : ndarray, shape(n_trials, n_channels, n_samples)
            Test signal.
        **kwargs
            Extra keyword arguments forwarded to every sub-band estimator's
            ``transform``.

        Returns
        -------
        feat : ndarray
            Sub-band features concatenated along the last axis, in filter-bank
            order. Presumably shape (n_trials, n_bands * n_features) when each
            estimator returns (n_trials, n_features) -- depends on the estimator.
        """
        X = self.transform_filterbank(X)
        feat = [est.transform(X[i], **kwargs) for i, est in enumerate(self.estimators_)]
        feat = np.concatenate(feat, axis=-1)
        return feat

    def transform_filterbank(self, X: ndarray):
        """
        Filter the input signal through every filter of the filter bank.

        update log:
            2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

        Parameters
        ----------
        X : ndarray, shape(n_trials, n_channels, n_samples)
            Input signal.

        Returns
        -------
        Xs : ndarray, shape(Nfb, n_trials, n_channels, n_samples)
            Sub-band components of the input signal, stacked along a new
            leading axis in filter-bank order. Filtering is zero-phase
            (forward-backward ``sosfiltfilt``) along the last axis.
        """
        Xs = np.stack([sosfiltfilt(sos, X, axis=-1) for sos in self.filterbank])
        return Xs
class FilterBankSSVEP(FilterBank):
    """
    Filter bank analysis for SSVEP.

    The SSVEP signal is decomposed into sub-bands by a bank of band-pass
    filters, features are extracted per sub-band by clones of
    ``base_estimator``, and (optionally) the sub-band features are combined
    by a weighted sum.

    update log:
        2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

    Parameters
    ----------
    filterbank : list[ndarray]
        The filter bank (sos coefficients, one entry per sub-band).
    base_estimator : BaseEstimator
        Estimator for model training and feature extraction.
    filterweights : array-like, optional
        Per-sub-band weights used to sum the sub-band features; expected to
        have length ``len(filterbank)``. If None (default), the concatenated
        sub-band features are returned unweighted.
    n_jobs : int, optional
        Sets the number of CPU working cores. The default is None.
    """

    def __init__(
        self,
        filterbank: List[ndarray],
        base_estimator: BaseEstimator,
        filterweights: Optional[ndarray] = None,
        n_jobs: Optional[int] = None,
    ):
        self.filterweights = filterweights
        super().__init__(base_estimator, filterbank, n_jobs=n_jobs)

    def transform(self, X: ndarray, **kwargs):
        """
        Convert X into features, optionally combining the sub-band features
        with ``filterweights``.

        update log:
            2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

        Parameters
        ----------
        X : ndarray, shape(n_trials, n_channels, n_samples)
            Test signal.
        **kwargs
            Extra keyword arguments forwarded to every sub-band estimator's
            ``transform``.

        Returns
        -------
        features : ndarray
            Without ``filterweights``: the concatenated sub-band features,
            shape (n_trials, n_bands * n_features). With ``filterweights``:
            their weighted sum over sub-bands, shape (n_trials, n_features).
        """
        features = super().transform(X, **kwargs)
        if self.filterweights is None:
            return features
        # The parent concatenates sub-band outputs in filter-bank order, so
        # this reshape splits them back into (n_trials, n_bands, n_features).
        features = np.reshape(
            features, (features.shape[0], len(self.filterbank), -1)
        )
        # asarray lets callers pass the weights as a list or tuple too.
        weights = np.asarray(self.filterweights)
        return np.sum(features * weights[np.newaxis, :, np.newaxis], axis=1)
def generate_filterbank(
    passbands: List[Tuple[float, float]],
    stopbands: List[Tuple[float, float]],
    srate: int,
    order: Optional[int] = None,
    rp: float = 0.5,
    gpass: float = 3,
    gstop: float = 40,
):
    """
    Create a filter bank of Chebyshev type I band-pass filters that divide the
    input signal into multiple sub-band components.

    update log:
        2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

    Parameters
    ----------
    passbands : list of tuple(float, float)
        Passband edges (Hz), one (low, high) pair per sub-band.
    stopbands : list of tuple(float, float)
        Stopband edges (Hz), one pair per sub-band. Only used to choose the
        order when ``order`` is None. Pairs are matched to ``passbands`` with
        ``zip``, so the shorter list determines the number of filters.
    srate : float
        Sampling rate (Hz).
    order : int, optional
        Filter order. If None (default), the minimum order meeting
        ``gpass``/``gstop`` is computed per sub-band with ``cheb1ord``.
    rp : float
        Maximum passband ripple (dB) of the designed filter, 0.5 by default.
    gpass : float
        Maximum passband loss (dB) used by ``cheb1ord`` for order selection,
        3 by default. Ignored when ``order`` is given.
    gstop : float
        Minimum stopband attenuation (dB) used by ``cheb1ord`` for order
        selection, 40 by default. Ignored when ``order`` is given.

    Returns
    -------
    filterbank : list[ndarray]
        One second-order-sections array of shape (n_sections, 6) per sub-band.
        The number of sections may differ between sub-bands when the order is
        selected automatically.
    """
    filterbank = []
    for wp, ws in zip(passbands, stopbands):
        if order is None:
            # Order is chosen from (gpass, gstop); the filter itself is then
            # designed with passband ripple rp.
            N, wn = cheb1ord(wp, ws, gpass, gstop, fs=srate)
            sos = cheby1(N, rp, wn, btype="bandpass", output="sos", fs=srate)
        else:
            sos = cheby1(order, rp, wp, btype="bandpass", output="sos", fs=srate)
        filterbank.append(sos)
    return filterbank
def process(data):
    """Center the rows of ``data`` and return an orthonormal basis of the result.

    This is the centring + QR step used by canonical correlation analysis
    (it reuses MATLAB ``canoncorr``'s error identifier). Each row of ``data``
    (one signal) is made zero-mean. The centered data is transposed to
    (n_samples, n_signals) and factorised with an economy-size QR
    decomposition. The columns of Q are then truncated to the numerical rank.

    Parameters
    ----------
    data : ndarray or numpy.matrix, shape (n_signals, n_samples)
        Signals stored row-wise.

    Returns
    -------
    Q : ndarray, shape (n_samples, rank)
        Orthonormal basis of the column space of the centered, transposed data.
    rank : int
        Numerical rank of the centered data, estimated with
        ``np.linalg.matrix_rank`` on the R factor.

    Raises
    ------
    ValueError
        If the centered data has rank 0 (e.g. every row is constant).
    """
    # np.asarray unwraps numpy.matrix input into a plain ndarray, so both
    # types work. The previous np.mat-based code was removed along with
    # np.mat in NumPy 2.0.
    data = np.asarray(data)
    # Remove each signal's (row's) mean; keepdims keeps the result broadcastable.
    centered = data - data.mean(axis=1, keepdims=True)
    # QR of (n_samples, n_signals).
    q, r = qr(centered.T, mode='economic')
    rank = np.linalg.matrix_rank(r)
    if rank == 0:
        raise ValueError('stats:canoncorr:badData')
    # rank <= min(n_samples, n_signals), so this keeps exactly the columns
    # spanning the data (a no-op when the data has full rank).
    q = q[:, :rank]
    return q, rank
def reference(listFreqs, fs, numberSmples, num_harms):
    """Build centered, orthonormalised sine/cosine references per frequency.

    For each stimulus frequency f, the rows
    ``sin(2*pi*h*f*t), cos(2*pi*h*f*t)`` for h = 1..num_harms are built with
    ``t = [1, 2, ..., numberSmples] / fs``. They are passed through
    ``process``, and the transposed Q factor is stored.

    Parameters
    ----------
    listFreqs : sequence of float
        Stimulus frequencies (Hz).
    fs : float
        Sampling rate (Hz).
    numberSmples : int
        Number of samples per reference signal.
    num_harms : int
        Number of harmonics (each contributes one sine and one cosine row).

    Returns
    -------
    referenceData : ndarray, shape (len(listFreqs), 2 * num_harms, numberSmples)
        For each frequency, ``Q.T`` as returned by ``process``.

    Raises
    ------
    ValueError
        If the reference set of some frequency is rank-deficient (e.g. a
        harmonic at the Nyquist frequency or too few samples), or if
        ``process`` rejects it.
    """
    time_index = np.arange(1, numberSmples + 1) / fs
    n_rows = 2 * num_harms
    reference_data = np.zeros((len(listFreqs), n_rows, numberSmples))
    for freq_idx, stim_freq in enumerate(listFreqs):
        rows = []
        for harm in range(1, num_harms + 1):
            phase = 2 * np.pi * time_index * harm * stim_freq
            rows.extend([np.sin(phase), np.cos(phase)])
        # np.asmatrix replaces np.mat (removed in NumPy 2.0); process is
        # given the same np.matrix input it always received.
        q, rank = process(np.asmatrix(rows))
        # A truncated Q cannot fill the fixed-size slot. Rank 1 would
        # otherwise be silently broadcast into every row.
        if rank < n_rows:
            raise ValueError(
                f"Reference signals for frequency {stim_freq} are rank-deficient "
                f"(rank {rank} < {n_rows})."
            )
        reference_data[freq_idx] = np.transpose(q)
    return reference_data
def generate_cca_references(
    freqs: Union[ndarray, int, float],
    srate,
    T,
    phases: Optional[Union[ndarray, int, float]] = None,
    n_harmonics: int = 1,
):
    """
    Construct sine-cosine reference signals for canonical correlation
    analysis (CCA).

    update log:
        2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

    Parameters
    ----------
    freqs : int, float or array-like
        Stimulus frequency or frequencies (Hz).
    srate : int
        Sampling rate (Hz).
    T : float
        Signal duration (s); ``int(T * srate)`` samples are generated.
    phases : int, float or array-like, optional
        Initial phase per frequency, in units of pi radians (the value is
        multiplied by pi). Default is None, meaning phase 0.
    n_harmonics : int
        The number of harmonics (>= 1). The default value is 1.

    Returns
    -------
    Yf : ndarray, shape (n_freqs, 2 * n_harmonics, int(T * srate))
        Reference signals. Rows are ordered sin(f), cos(f), sin(2f), cos(2f), ...

    Raises
    ------
    ValueError
        If ``n_harmonics`` < 1.
    """
    if n_harmonics < 1:
        raise ValueError("n_harmonics must be a positive integer.")
    # atleast_1d handles Python and NumPy scalars (e.g. np.int64) as well as
    # sequences; the trailing axis broadcasts against the time vector.
    freqs = np.atleast_1d(np.asarray(freqs))[:, np.newaxis]
    if phases is None:
        phases = 0
    phases = np.atleast_1d(np.asarray(phases))[:, np.newaxis]
    n_samples = int(T * srate)
    # Samples are spaced exactly 1/srate apart, starting at 0. The previous
    # np.linspace(0, T, n_samples) included the endpoint T, giving a spacing
    # of T / (n_samples - 1) and slightly mistuned references.
    t = np.arange(n_samples) / srate
    Yf = []
    for i in range(n_harmonics):
        Yf.append(
            np.stack(
                [
                    np.sin(2 * np.pi * (i + 1) * freqs * t + np.pi * phases),
                    np.cos(2 * np.pi * (i + 1) * freqs * t + np.pi * phases),
                ],
                axis=1,
            )
        )
    Yf = np.concatenate(Yf, axis=1)
    return Yf
def sign_flip(u, s, vh=None):
    """Flip signs of SVD or EIG using the method in paper [1]_.

    Each component k is scored by ``sum(u[:, k] * s[k])``. When ``vh`` is
    given, ``sum(s[k] * vh[k, :])`` is added to the score. The component's
    vectors are negated when that score is negative, making the sign
    deterministic. A score of exactly zero keeps the original sign and emits
    a warning.

    update log:
        2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

    Parameters
    ----------
    u: ndarray
        left singular vectors, shape (M, K).
    s: ndarray
        singular values, shape (K,).
    vh: ndarray or None
        transpose of right singular vectors, shape (K, N).

    Returns
    -------
    u: ndarray
        corrected left singular vectors.
    s: ndarray
        singular values (returned unchanged).
    vh: ndarray
        transpose of corrected right singular vectors; only returned when
        ``vh`` was given, otherwise the result is the pair ``(u, s)``.

    References
    ----------
    .. [1] https://www.sandia.gov/~tgkolda/pubs/pubfiles/SAND2007-6422.pdf
    """
    has_vh = vh is not None
    right_score = np.sum(u * s, axis=0)
    if has_vh:
        left_score = np.sum(s[:, np.newaxis] * vh, axis=-1)
        score = left_score + right_score
    else:
        score = right_score

    flips = np.sign(score)
    undecided = flips == 0
    if np.any(undecided):
        # No preferred direction: keep the original orientation.
        flips[undecided] = 1
        warnings.warn(
            "The magnitude is close to zero, the sign will become arbitrary."
        )

    u_fixed = u * flips
    if not has_vh:
        return u_fixed, s
    return u_fixed, s, flips[:, np.newaxis] * vh