# -*- coding: utf-8 -*-
#
# Authors: Swolf
# Date: 2021/1/07
# License: MIT License
from typing import Optional, List, Tuple, Union
import warnings

import numpy as np
from numpy import ndarray
# Kept so that existing ``from <this module> import linalg`` users keep
# working; the code below calls ``np.linalg`` directly because the
# ``numpy.linalg.linalg`` submodule is deprecated in recent NumPy releases.
from numpy.linalg import linalg  # noqa: F401
from scipy.linalg import solve, qr
from scipy.signal import sosfiltfilt, cheby1, cheb1ord
from sklearn.base import BaseEstimator, TransformerMixin, clone


def robust_pattern(W: ndarray, Cx: ndarray, Cs: ndarray) -> ndarray:
    """Transform spatial filters to spatial patterns based on paper [1]_.

    A spatial filter only describes how to combine channels in order to
    extract a signal of interest; for neurophysiological interpretation or
    visualisation, the corresponding activation patterns have to be built
    from the filters. This function computes ``A = Cx @ W @ inv(Cs)``.

    update log:
        2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

    Parameters
    ----------
    W : ndarray
        Spatial filters, shape (n_channels, n_filters).
    Cx : ndarray
        Covariance matrix of eeg data, shape (n_channels, n_channels).
    Cs : ndarray
        Covariance matrix of source data, shape (n_filters, n_filters)
        (it must be square and match the number of columns of ``W`` for the
        linear solve below to be well-formed).

    Returns
    -------
    A : ndarray
        Spatial patterns, shape (n_channels, n_filters), each column is a
        spatial pattern.

    References
    ----------
    .. [1] Haufe, Stefan, et al. "On the interpretation of weight vectors of
       linear models in multivariate neuroimaging." Neuroimage 87 (2014):
       96-110.
    """
    # Use a linear solve instead of an explicit inverse; it is more stable.
    # A @ Cs = Cx @ W  <=>  Cs.T @ A.T = (Cx @ W).T
    # see https://github.com/robintibor/fbcsp/blob/master/fbcsp/signalproc.py
    # and https://ww2.mathworks.cn/help/matlab/ref/mldivide.html
    A = solve(Cs.T, np.dot(Cx, W).T).T
    return A


class FilterBank(BaseEstimator, TransformerMixin):
    """Apply one clone of an estimator to every sub-band of a filter bank.

    The input signal is filtered by each second-order-sections (SOS) filter
    in ``filterbank``; an independent clone of ``base_estimator`` is fitted
    on (and later transforms) each filtered copy, and the per-band features
    are concatenated along the last axis.

    update log:
        2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

    Parameters
    ----------
    base_estimator : BaseEstimator
        Estimator cloned once per sub-band for training and feature
        extraction.
    filterbank : list[ndarray]
        Band-pass filters in SOS form, one per sub-band (they are passed to
        ``scipy.signal.sosfiltfilt``).
    n_jobs : int, optional
        Number of CPU workers. Stored for API compatibility but currently
        not used: the sub-bands are processed sequentially. Default is None.

    References
    ----------
    .. [1] Chen X, Wang Y, Nakanishi M, et al. High-speed spelling with a
       noninvasive brain-computer interface[J]. Proceedings of the national
       academy of sciences, 2015, 112(44): E6058-E6067.
    """

    def __init__(
        self,
        base_estimator: BaseEstimator,
        filterbank: List[ndarray],
        n_jobs: Optional[int] = None,
    ):
        self.base_estimator = base_estimator
        self.filterbank = filterbank
        self.n_jobs = n_jobs

    def fit(self, X: ndarray, y: Optional[ndarray] = None, **kwargs):
        """Fit one clone of the base estimator on each sub-band of ``X``.

        update log:
            2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

        Parameters
        ----------
        X : ndarray
            Training signal; it is filtered along its last axis by every
            filter of the bank before being handed to the estimators.
        y : ndarray, optional
            Labels, forwarded unchanged to every estimator's ``fit``.
        **kwargs
            Extra keyword arguments forwarded to every estimator's ``fit``
            (for example a reference signal, if the base estimator takes
            one).

        Returns
        -------
        self : FilterBank
            The fitted instance; the per-band estimators are stored in
            ``self.estimators_``.
        """
        self.estimators_ = [
            clone(self.base_estimator) for _ in range(len(self.filterbank))
        ]
        Xs = self.transform_filterbank(X)
        for est, X_band in zip(self.estimators_, Xs):
            est.fit(X_band, y, **kwargs)
        return self

    def transform(self, X: ndarray, **kwargs):
        """Extract and concatenate the features of every sub-band.

        update log:
            2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

        Parameters
        ----------
        X : ndarray, shape(n_trials, n_channels, n_samples)
            Test signal.
        **kwargs
            Extra keyword arguments forwarded to every estimator's
            ``transform``.

        Returns
        -------
        feat : ndarray
            Per-band outputs of the fitted estimators concatenated along the
            last axis, band after band (for 2-D estimator outputs this is
            shape (n_trials, Nfb * n_features)).
        """
        Xs = self.transform_filterbank(X)
        feat = [
            est.transform(X_band, **kwargs)
            for est, X_band in zip(self.estimators_, Xs)
        ]
        return np.concatenate(feat, axis=-1)

    def transform_filterbank(self, X: ndarray):
        """Filter the input signal with every filter of the bank.

        update log:
            2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

        Parameters
        ----------
        X : ndarray, shape(n_trials, n_channels, n_samples)
            Input signal; zero-phase filtering is applied along the last
            axis.

        Returns
        -------
        Xs : ndarray, shape(Nfb, n_trials, n_channels, n_samples)
            Sub-band components of the input signal, stacked along a new
            leading axis in the order of ``self.filterbank``.
        """
        Xs = np.stack([sosfiltfilt(sos, X, axis=-1) for sos in self.filterbank])
        return Xs


class FilterBankSSVEP(FilterBank):
    """Filter bank analysis for SSVEP.

    The SSVEP signal is decomposed into sub-bands by the filter bank, the
    base estimator extracts features from each sub-band, and the per-band
    features are optionally combined by a weighted sum.

    update log:
        2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

    Parameters
    ----------
    filterbank : list[ndarray]
        The filter bank (SOS filters, one per sub-band).
    base_estimator : BaseEstimator
        Estimator for model training and feature extraction.
    filterweights : array-like of shape (Nfb,), optional
        One weight per sub-band. When None (default) the per-band features
        are returned concatenated instead of being summed.
    n_jobs : int, optional
        Number of CPU workers (see :class:`FilterBank`). Default is None.
    """

    def __init__(
        self,
        filterbank: List[ndarray],
        base_estimator: BaseEstimator,
        filterweights: Optional[ndarray] = None,
        n_jobs: Optional[int] = None,
    ):
        self.filterweights = filterweights
        super().__init__(base_estimator, filterbank, n_jobs=n_jobs)

    def transform(self, X: ndarray):  # type: ignore[override]
        """Extract sub-band features and optionally fuse them with weights.

        update log:
            2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

        Parameters
        ----------
        X : ndarray, shape(n_trials, n_channels, n_samples)
            Test signal.

        Returns
        -------
        features : ndarray
            If ``filterweights`` is None, the concatenated per-band features
            returned by :meth:`FilterBank.transform`. Otherwise the features
            are reshaped to (n_trials, Nfb, n_features) and the weighted sum
            over the sub-band axis is returned, shape
            (n_trials, n_features).
        """
        features = super().transform(X)
        if self.filterweights is None:
            return features

        # Accept lists/tuples as well as arrays: tuple-indexing with
        # np.newaxis only works on ndarrays.
        weights = np.asarray(self.filterweights)
        features = np.reshape(
            features, (features.shape[0], len(self.filterbank), -1)
        )
        return np.sum(features * weights[np.newaxis, :, np.newaxis], axis=1)


def generate_filterbank(
    passbands: List[Tuple[float, float]],
    stopbands: List[Tuple[float, float]],
    srate: int,
    order: Optional[int] = None,
    rp: float = 0.5,
) -> List[ndarray]:
    """Create a bank of Chebyshev type I band-pass filters in SOS form.

    update log:
        2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

    Parameters
    ----------
    passbands : list of tuple(float, float)
        Pass-band edge frequencies (low, high) of every sub-band.
    stopbands : list of tuple(float, float)
        Stop-band edge frequencies (low, high) of every sub-band. They are
        paired with ``passbands`` element by element; the values are only
        used when ``order`` is None.
    srate : int
        Sampling rate.
    order : int, optional
        Filter order. When None (default), the lowest order meeting a 3 dB
        pass-band loss / 40 dB stop-band attenuation specification is
        selected with ``scipy.signal.cheb1ord``.
    rp : float
        Maximum ripple allowed below unity gain in the pass-band, passed to
        ``scipy.signal.cheby1``. Default is 0.5.

    Returns
    -------
    filterbank : list[ndarray]
        One SOS coefficient array of shape (n_sections, 6) per pass-band.
    """
    filterbank = []
    for wp, ws in zip(passbands, stopbands):
        if order is None:
            N, wn = cheb1ord(wp, ws, 3, 40, fs=srate)
            sos = cheby1(N, rp, wn, btype="bandpass", output="sos", fs=srate)
        else:
            sos = cheby1(order, rp, wp, btype="bandpass", output="sos", fs=srate)
        filterbank.append(sos)
    return filterbank


def process(data):
    """Centre the rows of ``data`` and orthonormalise them by QR.

    Parameters
    ----------
    data : array-like, shape (n_signals, n_samples)
        Signals stored row-wise. Both ``numpy.ndarray`` and ``numpy.matrix``
        inputs are accepted.

    Returns
    -------
    Q : ndarray, shape (n_samples, k)
        Orthonormal columns from the economic QR decomposition of the
        centred, transposed data, truncated to the first ``rankQ`` columns.
    rankQ : int
        Numerical rank of the triangular factor ``R``.

    Raises
    ------
    ValueError
        If the centred data has rank zero.
    """
    data = np.asarray(data)

    # Centring: remove the mean of every row.
    centered = data - data.mean(axis=1, keepdims=True)
    n_signals = centered.shape[0]

    # Economic QR decomposition of the (n_samples, n_signals) matrix.
    Q, R = qr(centered.T, mode='economic')

    # Rank of the matrix.
    rankQ = np.linalg.matrix_rank(R)
    if rankQ == 0:
        raise ValueError('stats:canoncorr:badData')
    if rankQ <= n_signals:
        # warnings.warn('stats:canoncorr:NotFullRank')
        Q = Q[:, :rankQ]
    return Q, rankQ


def reference(listFreqs, fs, numberSmples, num_harms):
    """Build orthonormalised sine/cosine references for every frequency.

    Parameters
    ----------
    listFreqs : sequence of float
        Stimulus frequencies in Hz.
    fs : float
        Sampling rate.
    numberSmples : int
        Number of samples of every reference signal.
    num_harms : int
        Number of harmonics; each contributes one sine and one cosine row.

    Returns
    -------
    referenceData : ndarray, shape (len(listFreqs), 2 * num_harms, numberSmples)
        For every frequency, the transposed ``Q`` factor returned by
        :func:`process` for the stacked sine/cosine signals.
    """
    numberFrequence = len(listFreqs)
    timeIndex = np.arange(1, numberSmples + 1) / fs  # time index
    referenceData = np.zeros((numberFrequence, 2 * num_harms, numberSmples))
    for frequenceIndex, stimFrequence in enumerate(listFreqs):  # in HZ
        temp = []
        for harmIndex in range(1, num_harms + 1):
            # Sin and Cos
            angle = 2 * np.pi * timeIndex * harmIndex * stimFrequence
            temp.extend([np.sin(angle), np.cos(angle)])
        # Centring and QR decomposition.
        Q, _ = process(np.asarray(temp))
        referenceData[frequenceIndex] = np.transpose(Q)
    return referenceData


def generate_cca_references(
    freqs: Union[ndarray, int, float],
    srate,
    T,
    phases: Optional[Union[ndarray, int, float]] = None,
    n_harmonics: int = 1,
):
    """Construct sine-cosine reference signals for canonical correlation
    analysis (CCA).

    update log:
        2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

    Parameters
    ----------
    freqs : int, float or array-like
        Stimulus frequency or frequencies.
    srate : int
        Sampling rate.
    T : int or float
        Duration of the reference signals.
    phases : int, float or array-like, optional
        Phase(s) expressed in units of pi (the value is multiplied by
        ``np.pi``). Must broadcast against ``freqs``. Default is None,
        meaning zero phase.
    n_harmonics : int
        The number of harmonics. The default value is 1.

    Returns
    -------
    Yf : ndarray, shape(n_freqs, 2 * n_harmonics, int(T * srate))
        Reference signals; along axis 1 the rows alternate sine and cosine,
        harmonic after harmonic.
    """
    # atleast_1d also covers NumPy scalars and 0-d arrays, which cannot be
    # indexed with ``[:, np.newaxis]`` directly.
    freqs = np.atleast_1d(np.asarray(freqs))[:, np.newaxis]
    if phases is None:
        phases = 0
    phases = np.atleast_1d(np.asarray(phases))[:, np.newaxis]

    # NOTE(review): linspace includes the end point, so the sample spacing
    # is T / (N - 1) rather than 1 / srate - confirm this is intended.
    t = np.linspace(0, T, int(T * srate))

    Yf = []
    for harmonic in range(1, n_harmonics + 1):
        angle = 2 * np.pi * harmonic * freqs * t + np.pi * phases
        Yf.append(np.stack([np.sin(angle), np.cos(angle)], axis=1))
    Yf = np.concatenate(Yf, axis=1)
    return Yf


def _resolve_signs(total_proj):
    """Return +1/-1 signs for ``total_proj``, mapping zeros to +1 with a warning."""
    signs = np.sign(total_proj)
    random_idx = signs == 0
    if np.any(random_idx):
        signs[random_idx] = 1
        warnings.warn(
            "The magnitude is close to zero, the sign will become arbitrary.",
            stacklevel=2,
        )
    return signs


def sign_flip(u, s, vh=None):
    """Flip signs of SVD or EIG using the method in paper [1]_.

    update log:
        2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

    Parameters
    ----------
    u: ndarray
        left singular vectors, shape (M, K).
    s: ndarray
        singular values, shape (K,).
    vh: ndarray or None
        transpose of right singular vectors, shape (K, N).

    Returns
    -------
    u: ndarray
        corrected left singular vectors.
    s: ndarray
        singular values.
    vh: ndarray
        transpose of corrected right singular vectors (only returned when
        ``vh`` is given; otherwise the result is the pair ``(u, s)``).

    Warns
    -----
    UserWarning
        If the projection sum of a component is exactly zero; its sign is
        then left positive.

    References
    ----------
    .. [1] https://www.sandia.gov/~tgkolda/pubs/pubfiles/SAND2007-6422.pdf
    """
    if vh is None:
        signs = _resolve_signs(np.sum(u * s, axis=0))
        return u * signs, s

    left_proj = np.sum(s[:, np.newaxis] * vh, axis=-1)
    right_proj = np.sum(u * s, axis=0)
    signs = _resolve_signs(left_proj + right_proj)
    # Flip both factors with the same signs so that u @ diag(s) @ vh is
    # unchanged.
    return u * signs, s, signs[:, np.newaxis] * vh