419 lines
13 KiB
Python
419 lines
13 KiB
Python
# -*- coding: utf-8 -*-
|
||
#
|
||
# Authors: Swolf <swolfforever@gmail.com>
|
||
# Date: 2021/1/07
|
||
# License: MIT License
|
||
|
||
|
||
from typing import Optional, List, Tuple, Union
|
||
import warnings
|
||
import numpy as np
|
||
from numpy import ndarray
|
||
from numpy.linalg import linalg
|
||
from scipy.linalg import solve, qr
|
||
from scipy.signal import sosfiltfilt, cheby1, cheb1ord
|
||
from sklearn.base import BaseEstimator, TransformerMixin, clone
|
||
|
||
|
||
def robust_pattern(W: ndarray, Cx: ndarray, Cs: ndarray) -> ndarray:
    """Convert spatial filters into spatial (activation) patterns [1]_.

    A spatial filter only tells how the channels are combined to extract a
    component; for neurophysiological interpretation or for plotting, the
    corresponding activation pattern is needed instead. Following [1]_ the
    patterns are obtained as ``A = Cx @ W @ inv(Cs)``.

    update log:
        2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

    Parameters
    ----------
    W : ndarray
        Spatial filters, shape (n_channels, n_filters).
    Cx : ndarray
        Covariance matrix of the EEG data, shape (n_channels, n_channels).
    Cs : ndarray
        Covariance matrix of the source data. The linear solve below needs
        it to be square with one row per column of ``W``, i.e. shape
        (n_filters, n_filters).

    Returns
    -------
    A : ndarray
        Spatial patterns, shape (n_channels, n_filters); each column is the
        pattern that belongs to the matching column of ``W``.

    References
    ----------
    .. [1] Haufe, Stefan, et al. "On the interpretation of weight vectors of
       linear models in multivariate neuroimaging." Neuroimage 87 (2014):
       96-110.
    """
    # The explicit inverse of Cs is avoided on purpose: a linear solve is
    # numerically more stable (same idea as MATLAB's mldivide).
    # see https://github.com/robintibor/fbcsp/blob/master/fbcsp/signalproc.py
    # and https://ww2.mathworks.cn/help/matlab/ref/mldivide.html
    filtered_cov = np.dot(Cx, W)
    # Solving Cs.T @ Z = filtered_cov.T gives Z = (Cx @ W @ inv(Cs)).T
    patterns_transposed = solve(Cs.T, filtered_cov.T)
    return patterns_transposed.T
|
||
|
||
|
||
class FilterBank(BaseEstimator, TransformerMixin):
    """Run one copy of an estimator on every sub-band of a filter bank.

    The input signal is band-pass filtered once per entry of ``filterbank``;
    an independent clone of ``base_estimator`` is fitted on each sub-band and
    the per-band features are concatenated along the last axis.

    update log:
        2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

    Parameters
    ----------
    base_estimator : BaseEstimator
        Estimator that is cloned once per sub-band; it must provide ``fit``
        and ``transform``.
    filterbank : list[ndarray]
        Band-pass filters in second-order-sections form (each entry is
        handed to ``scipy.signal.sosfiltfilt``).
    n_jobs : int, optional
        Stored on the instance only; the visible implementation processes
        the sub-bands sequentially and never reads this value.

    References
    ----------
    .. [1] Chen X, Wang Y, Nakanishi M, et al. High-speed spelling with a
       noninvasive brain-computer interface[J]. Proceedings of the national
       academy of sciences, 2015, 112(44): E6058-E6067.
    """

    def __init__(
        self,
        base_estimator: BaseEstimator,
        filterbank: List[ndarray],
        n_jobs: Optional[int] = None,
    ):
        self.base_estimator = base_estimator
        self.filterbank = filterbank
        self.n_jobs = n_jobs

    def fit(self, X: ndarray, y: Optional[ndarray] = None, **kwargs):
        """Fit one clone of the base estimator on each sub-band of ``X``.

        update log:
            2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

        Parameters
        ----------
        X : ndarray
            Training signal; it is filtered along its last axis
            (documented elsewhere in this class as
            shape (n_trials, n_channels, n_samples)).
        y : ndarray, optional
            Labels, passed unchanged to every estimator's ``fit``.
        **kwargs
            Extra keyword arguments forwarded to every estimator's ``fit``
            (for example a reference signal, if the estimator accepts one).

        Returns
        -------
        self : FilterBank
            The fitted instance; the fitted clones are in ``estimators_``.
        """
        n_bands = len(self.filterbank)
        self.estimators_ = [clone(self.base_estimator) for _ in range(n_bands)]
        subbands = self.transform_filterbank(X)
        for band_idx, estimator in enumerate(self.estimators_):
            estimator.fit(subbands[band_idx], y, **kwargs)
        return self

    def transform(self, X: ndarray, **kwargs):
        """Extract features from every sub-band and concatenate them.

        update log:
            2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

        Parameters
        ----------
        X : ndarray, shape(n_trials, n_channels, n_samples)
            Signal to transform.
        **kwargs
            Extra keyword arguments forwarded to every estimator's
            ``transform``.

        Returns
        -------
        feat : ndarray
            Per-band outputs of the fitted estimators joined along the last
            axis, in filter-bank order.
        """
        subbands = self.transform_filterbank(X)
        per_band = []
        for band_idx, estimator in enumerate(self.estimators_):
            per_band.append(estimator.transform(subbands[band_idx], **kwargs))
        return np.concatenate(per_band, axis=-1)

    def transform_filterbank(self, X: ndarray):
        """Filter ``X`` with every filter of the bank.

        update log:
            2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

        Parameters
        ----------
        X : ndarray, shape(n_trials, n_channels, n_samples)
            Input signal; zero-phase filtering is applied along the last
            axis.

        Returns
        -------
        Xs : ndarray, shape(Nfb, n_trials, n_channels, n_samples)
            Sub-band components stacked on a new leading axis, one entry
            per filter in ``filterbank``.
        """
        filtered = [sosfiltfilt(sos, X, axis=-1) for sos in self.filterbank]
        return np.stack(filtered)
|
||
|
||
|
||
class FilterBankSSVEP(FilterBank):
    """Filter-bank analysis for SSVEP with optional sub-band weighting.

    Works like :class:`FilterBank`, but when ``filterweights`` is given the
    features of the individual sub-bands are combined into a single
    weighted sum instead of being returned side by side.

    update log:
        2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

    Parameters
    ----------
    filterbank : list[ndarray]
        The filter bank.
    base_estimator : BaseEstimator
        Estimator used for model training and feature extraction.
    filterweights : ndarray, optional
        One weight per sub-band. It is indexed as
        ``filterweights[np.newaxis, :, np.newaxis]``, so it has to be a
        1-D ndarray. ``None`` (default) disables the weighting.
    n_jobs : int, optional
        Passed on to :class:`FilterBank`. The default is None.
    """

    def __init__(
        self,
        filterbank: List[ndarray],
        base_estimator: BaseEstimator,
        filterweights: Optional[ndarray] = None,
        n_jobs: Optional[int] = None,
    ):
        super().__init__(base_estimator, filterbank, n_jobs=n_jobs)
        self.filterweights = filterweights

    def transform(self, X: ndarray):  # type: ignore[override]
        """Compute sub-band features and, if configured, their weighted sum.

        update log:
            2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

        Parameters
        ----------
        X : ndarray, shape(n_trials, n_channels, n_samples)
            Signal to transform.

        Returns
        -------
        features : ndarray
            Without ``filterweights``: the concatenated per-band features
            from :meth:`FilterBank.transform`. With ``filterweights``: those
            features regrouped per band and summed over the band axis after
            multiplying each band by its weight.
        """
        features = super().transform(X)
        if self.filterweights is None:
            return features

        # The parent concatenates band after band along the last axis, so
        # the features can be regrouped as (n_trials, n_bands, n_features).
        band_shape = (features.shape[0], len(self.filterbank), -1)
        per_band = np.reshape(features, band_shape)
        weights = self.filterweights[np.newaxis, :, np.newaxis]
        return np.sum(per_band * weights, axis=1)
|
||
|
||
|
||
|
||
def generate_filterbank(
    passbands: List[Tuple[float, float]],
    stopbands: List[Tuple[float, float]],
    srate: int,
    order: Optional[int] = None,
    rp: float = 0.5,
):
    """Design one Chebyshev type I band-pass filter per pass/stop band pair.

    update log:
        2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

    Parameters
    ----------
    passbands : list of tuple(float, float)
        Pass-band edges (low, high) of every sub-band.
    stopbands : list of tuple(float, float)
        Stop-band edges (low, high) of every sub-band. The two lists are
        walked in lockstep with ``zip``. When ``order`` is given the
        stop-band values are not used for the design.
    srate : int
        Sampling rate, passed to scipy as ``fs``.
    order : int, optional
        Fixed filter order. If None (default) the order is estimated per
        band with ``cheb1ord`` using gpass=3 and gstop=40.
    rp : float
        Maximum ripple allowed below unity gain in the pass band,
        0.5 by default.

    Returns
    -------
    filterbank : list of ndarray
        One second-order-sections array of shape (n_sections, 6) per band.
    """

    def _design(wp, ws):
        # A fixed order is designed directly on the pass-band edges.
        if order is not None:
            return cheby1(order, rp, wp, btype="bandpass", output="sos", fs=srate)
        # Otherwise let scipy choose the order and the natural frequency.
        N, wn = cheb1ord(wp, ws, 3, 40, fs=srate)
        return cheby1(N, rp, wn, btype="bandpass", output="sos", fs=srate)

    return [_design(wp, ws) for wp, ws in zip(passbands, stopbands)]
|
||
|
||
def process(data):
    """Center the rows of ``data`` and orthonormalize them with a QR step.

    Each row is treated as one variable and each column as one observation.
    The row means are removed, the centered data is transposed and an
    economic QR decomposition is computed. The leading ``rank`` columns of
    ``Q`` are returned together with the rank.

    Parameters
    ----------
    data : array_like, shape (n_rows, n_samples)
        Two-dimensional input. Both plain ndarrays and ``np.matrix``
        objects are accepted.

    Returns
    -------
    Q : ndarray, shape (n_samples, rankQ)
        Leading ``rankQ`` columns of the economic ``Q`` factor.
    rankQ : int
        Numerical rank of the triangular factor ``R``.

    Raises
    ------
    ValueError
        If the centered data has rank zero (e.g. every row is constant).
    """
    # np.mat was removed in NumPy 2.0 and the old code only worked for
    # matrix inputs; plain ndarray arithmetic handles both input kinds.
    data = np.asarray(data)

    # Remove the mean of every row (centering).
    centered = data - data.mean(axis=1, keepdims=True)

    # QR decomposition with the observations along the rows.
    Q, R = qr(centered.T, mode='economic')

    # Numerical rank of the centered data.
    rankQ = np.linalg.matrix_rank(R)
    if rankQ == 0:
        raise ValueError('stats:canoncorr:badData')

    # The rank can never exceed the number of input rows, so the old
    # "rank <= n_rows" test was always true: truncate unconditionally.
    Q = Q[:, :rankQ]
    return Q, rankQ
|
||
|
||
def reference(listFreqs, fs, numberSmples, num_harms):
    """Build orthonormalized sine/cosine reference templates per frequency.

    For every stimulus frequency the sine and cosine of each harmonic
    (1 .. ``num_harms``) are sampled, then centered and QR-orthonormalized
    by ``process``. The stored rows are therefore the transposed ``Q``
    factor, not the raw sinusoids.

    Parameters
    ----------
    listFreqs : sequence of float
        Stimulus frequencies in Hz.
    fs : float
        Sampling rate; the time axis is ``arange(1, numberSmples + 1) / fs``,
        i.e. it starts at ``1 / fs`` rather than at zero.
    numberSmples : int
        Number of samples per template.
    num_harms : int
        Number of harmonics; each contributes one sine and one cosine row.

    Returns
    -------
    referenceData : ndarray, shape (len(listFreqs), 2 * num_harms, numberSmples)
        Templates for every frequency. If the basis of a frequency is rank
        deficient only the first ``rank`` rows are filled and the remaining
        rows stay zero.
    """
    numberFrequence = len(listFreqs)
    timeIndex = np.arange(1, numberSmples + 1) / fs  # time index
    referenceData = np.zeros((numberFrequence, 2 * num_harms, numberSmples))
    for frequenceIndex, stimFrequence in enumerate(listFreqs):  # in HZ
        temp = []
        for harmIndex in range(1, num_harms + 1):
            # Sin and Cos of the current harmonic
            angle = 2 * np.pi * timeIndex * harmIndex * stimFrequence
            temp.extend([np.sin(angle), np.cos(angle)])
        # np.mat was removed in NumPy 2.0; np.asmatrix is the supported
        # spelling. A matrix is kept because process() has always been
        # handed a matrix from here.
        referenceTemp = np.asmatrix(np.array(temp))
        # Centering and QR decomposition
        Q, rankQ = process(referenceTemp)
        # Q has only rankQ columns; writing just those rows avoids a shape
        # error (or a silent row duplication) for a rank-deficient basis.
        referenceData[frequenceIndex, :rankQ] = np.transpose(Q)
    return referenceData
|
||
|
||
def generate_cca_references(
    freqs: Union[ndarray, int, float],
    srate,
    T,
    phases: Optional[Union[ndarray, int, float]] = None,
    n_harmonics: int = 1,
):
    """Construct sine-cosine reference signals for canonical correlation
    analysis (CCA).

    update log:
        2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

    Parameters
    ----------
    freqs : int, float or array_like
        Stimulus frequency or frequencies. Python numbers, NumPy scalars and
        1-D sequences are accepted.
    srate : int
        Sampling rate.
    T : int or float
        Duration of the reference; ``int(T * srate)`` samples are generated.
    phases : int, float or array_like, optional
        Phase offset(s) in units of pi (the code adds ``np.pi * phases``).
        Must broadcast against ``freqs``. ``None`` (default) means 0.
    n_harmonics : int
        The number of harmonics. The default value is 1.

    Returns
    -------
    Yf : ndarray, shape (n_freqs, 2 * n_harmonics, int(T * srate))
        For every frequency the rows alternate sine and cosine, harmonic by
        harmonic: sin(1f), cos(1f), sin(2f), cos(2f), ...
    """
    # atleast_1d also copes with NumPy scalars and 0-d arrays, which the
    # former isinstance(int/float) checks let through to a failing index.
    freqs = np.atleast_1d(np.asarray(freqs))[:, np.newaxis]
    if phases is None:
        phases = 0
    phases = np.atleast_1d(np.asarray(phases))[:, np.newaxis]

    # NOTE(review): linspace includes the end point, so the sample spacing
    # is T / (N - 1) rather than 1 / srate - confirm this is intended before
    # changing it, since callers may rely on the current values.
    t = np.linspace(0, T, int(T * srate))

    Yf = []
    for harmonic in range(1, n_harmonics + 1):
        angle = 2 * np.pi * harmonic * freqs * t + np.pi * phases
        Yf.append(np.stack([np.sin(angle), np.cos(angle)], axis=1))
    return np.concatenate(Yf, axis=1)
|
||
|
||
|
||
def sign_flip(u, s, vh=None):
    """Resolve the sign ambiguity of an SVD or EIG result, after [1]_.

    For every component the entries of ``u * s`` are summed over the rows
    (and, if ``vh`` is given, the entries of ``s[:, None] * vh`` over the
    columns are added). The component is multiplied by the sign of that
    total. A total of exactly zero keeps the component unchanged and
    triggers a warning.

    update log:
        2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

    Parameters
    ----------
    u: ndarray
        left singular vectors, shape (M, K).
    s: ndarray
        singular values, shape (K,).
    vh: ndarray or None
        transpose of right singular vectors, shape (K, N).

    Returns
    -------
    u: ndarray
        corrected left singular vectors.
    s: ndarray
        singular values, returned unchanged.
    vh: ndarray
        transpose of corrected right singular vectors; only part of the
        returned tuple when ``vh`` was supplied.

    References
    ----------
    .. [1] https://www.sandia.gov/~tgkolda/pubs/pubfiles/SAND2007-6422.pdf
    """
    with_right = vh is not None

    # Projection total per component; the vh term exists only for an SVD.
    if with_right:
        projection = np.sum(s[:, np.newaxis] * vh, axis=-1) + np.sum(u * s, axis=0)
    else:
        projection = np.sum(u * s, axis=0)

    signs = np.sign(projection)
    undecided = signs == 0
    if np.any(undecided):
        # No preferred direction: keep the component as it is.
        signs[undecided] = 1
        warnings.warn(
            "The magnitude is close to zero, the sign will become arbitrary."
        )

    flipped_u = u * signs
    if not with_right:
        return flipped_u, s
    return flipped_u, s, signs[:, np.newaxis] * vh
|