初始化zmq 项目

This commit is contained in:
2026-06-05 09:34:29 +08:00
commit e6e4fb7da7
36 changed files with 6470 additions and 0 deletions

418
SSMVEP/algorithm/base.py Normal file
View File

@@ -0,0 +1,418 @@
# -*- coding: utf-8 -*-
#
# Authors: Swolf <swolfforever@gmail.com>
# Date: 2021/1/07
# License: MIT License
from typing import Optional, List, Tuple, Union
import warnings
import numpy as np
from numpy import ndarray
from numpy.linalg import linalg
from scipy.linalg import solve, qr
from scipy.signal import sosfiltfilt, cheby1, cheb1ord
from sklearn.base import BaseEstimator, TransformerMixin, clone
def robust_pattern(W : ndarray, Cx: ndarray, Cs: ndarray) -> ndarray:
"""Transform spatial filters to spatial patterns based on paper [1]_.
Referring to the method mentioned in article [1],the constructed spatial filter only shows how to combine
information from different channels to extract signals of interest from EEG signals, but if our goal is
neurophysiological interpretation or visualization of weights, activation patterns need to be constructed
from the obtained spatial filters.
update log:
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
Parameters
----------
W : ndarray
Spatial filters, shape (n_channels, n_filters).
Cx : ndarray
Covariance matrix of eeg data, shape (n_channels, n_channels).
Cs : ndarray
Covariance matrix of source data, shape (n_channels, n_channels).
Returns
-------
A : ndarray
Spatial patterns, shape (n_channels, n_patterns), each column is a spatial pattern.
References
----------
.. [1] Haufe, Stefan, et al. "On the interpretation of weight vectors of linear models in multivariate neuroimaging.
Neuroimage 87 (2014): 96-110.
"""
# use linalg.solve instead of inv, makes it more stable
# see https://github.com/robintibor/fbcsp/blob/master/fbcsp/signalproc.py
# and https://ww2.mathworks.cn/help/matlab/ref/mldivide.html
A = solve(Cs.T, np.dot(Cx, W).T).T
return A
class FilterBank(BaseEstimator, TransformerMixin):
"""
Filter bank decomposition is a bandpass filter array that divides the input signal into
multiple subband components and obtains the eigenvalues of each subband component.
update log:
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
Parameters
----------
base_estimator : class
Estimator for model training and feature extraction.
filterbank : list[ndarray]
A bandpass filter bank used to divide the input signal into multiple subband components.
n_jobs : int
Sets the number of CPU working cores. The default is None.
References
----------
.. [1] Chen X, Wang Y, Nakanishi M, et al. High-speed spelling with a noninvasive brain-computer interface[J].
Proceedings of the national academy of sciences, 2015, 112(44): E6058-E6067.
"""
def __init__(
self,
base_estimator: BaseEstimator,
filterbank: List[ndarray],
n_jobs: Optional[int] = None,
):
self.base_estimator = base_estimator
self.filterbank = filterbank
self.n_jobs = n_jobs
def fit(self, X: ndarray, y: Optional[ndarray] = None, **kwargs):
"""
Training model
update log:
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
Parameters
----------
X : None
Training signal (parameters can be ignored, only used to maintain code structure).
y : None
Label data (ibid., ignorable).
Yf : None
Reference signal (ibid., ignorable).
"""
self.estimators_ = [
clone(self.base_estimator) for _ in range(len(self.filterbank))
]
X = self.transform_filterbank(X)
for i, est in enumerate(self.estimators_):
est.fit(X[i], y, **kwargs)
# def wrapper(est, X, y, kwargs):
# est.fit(X, y, **kwargs)
# return est
# self.estimators_ = Parallel(n_jobs=self.n_jobs)(
# delayed(wrapper)(est, X[i], y, kwargs) for i, est in enumerate(self.estimators_))
return self
def transform(self, X: ndarray, **kwargs):
"""
The parameters stored in self are used to convert X into features, and X is filtered through the filter bank to
obtain the eigenvalues of each subband component.
update log:
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
Parameters
----------
X : ndarray, shape(n_trials, n_channels, n_samples)
Test the signal.
Returns
-------
feat : ndarray, shape(n_trials, n_fre)
Feature array.
"""
X = self.transform_filterbank(X)
feat = [est.transform(X[i], **kwargs) for i, est in enumerate(self.estimators_)]
# def wrapper(est, X, kwargs):
# retval = est.transform(X, **kwargs)
# return retval
# feat = Parallel(n_jobs=self.n_jobs)(
# delayed(wrapper)(est, X[i], kwargs) for i, est in enumerate(self.estimators_))
feat = np.concatenate(feat, axis=-1)
return feat
def transform_filterbank(self, X: ndarray):
"""
The input signal is filtered through a filter bank.
update log:
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
Parameters
----------
X : ndarray, shape(n_trials, n_channels, n_samples)
Input signal.
Returns
-------
Xs: ndarray, shape(Nfb, n_trials, n_channels, n_samples)
Individual subband components of the input signal.
"""
Xs = np.stack([sosfiltfilt(sos, X, axis=-1) for sos in self.filterbank])
return Xs
class FilterBankSSVEP(FilterBank):
"""
Filter bank analysis for SSVEP.
The SSVEP is analyzed using filter banks, that is, multiple filters are combined to decompose the SSVEP signal
into specific segments (subbands containing the original data) and obtain its characteristic data.
update log:
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
Parameters
----------
filterbank : list[ndarray]
The filter bank.
base_estimator : class
Estimator for model training and feature extraction.
filterweights : ndarray
Filter weight, default is None.
n_jobs : int
Sets the number of CPU working cores. The default is None.
"""
def __init__(
self,
filterbank: List[ndarray],
base_estimator: BaseEstimator,
filterweights: Optional[ndarray] = None,
n_jobs: Optional[int] = None,
):
self.filterweights = filterweights
super().__init__(base_estimator, filterbank, n_jobs=n_jobs)
def transform(self, X: ndarray): # type: ignore[override]
"""
X is converted into features by using the parameters stored in self, and the eigenvalues of each subband
component are obtained after the input signal is filtered by the filter bank.
update log:
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
Parameters
----------
X : ndarray, shape(n_trials, n_channels, n_samples)
Test the signal.
Returns
-------
features : ndarray, shape(n_trials, n_fre)
Feature array.
"""
features = super().transform(X)
if self.filterweights is None:
return features
else:
features = np.reshape(
features, (features.shape[0], len(self.filterbank), -1)
)
return np.sum(
features * self.filterweights[np.newaxis, :, np.newaxis], axis=1
)
def generate_filterbank(
passbands: List[Tuple[float, float]],
stopbands: List[Tuple[float, float]],
srate: int,
order: Optional[int] = None,
rp: float = 0.5,
):
"""
Create a filter bank, that is, obtain a bandpass filter coefficient that can divide the input signal into multiple
subband components.
update log:
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
Parameters
----------
passbands : list or tuple(float, float)
Passband parameters.
stopbands : list or tuple(float, float)
Stopband parameters.
srate : float
Sampling rate.
order : int
Filter order.
rp : float
The maximum ripple allowed in the passband below the unit gain is 0.5 by default.
Returns
-------
Filterbankndarray, shape(len(passbands), N, 6)
Filter bank coefficient.
"""
filterbank = []
for wp, ws in zip(passbands, stopbands):
if order is None:
N, wn = cheb1ord(wp, ws, 3, 40, fs=srate)
sos = cheby1(N, rp, wn, btype="bandpass", output="sos", fs=srate)
else:
sos = cheby1(order, rp, wp, btype="bandpass", output="sos", fs=srate)
filterbank.append(sos)
return filterbank
def process(data):
# 白化操作
meanValue = np.mat(data.mean(axis=1))
meanData = np.repeat(meanValue, data.shape[1], axis=1)
whiteTemp = data - meanData
# QR 分解
rankWhiteTemp = whiteTemp.shape[0]
whiteTemp = np.transpose(whiteTemp)
Q, R = qr(whiteTemp.A, mode='economic')
# 计算矩阵的秩
rankQ = linalg.matrix_rank(R)
if rankQ == 0:
raise ValueError('stats:canoncorr:badData')
elif rankQ <= rankWhiteTemp:
# warnings.warn('stats:canoncorr:NotFullRank')
Q = Q[:, 0:rankQ]
return Q, rankQ
def reference(listFreqs,fs, numberSmples, num_harms):
numberFrequence = len(listFreqs)
timeIndex = np.arange(1, numberSmples + 1) / fs # time index
referenceData = np.zeros((numberFrequence, 2 * num_harms, numberSmples))
for frequenceIndex in range(numberFrequence):
temp = []
for harmIndex in range(1, num_harms + 1):
stimFrequence = listFreqs[frequenceIndex] # in HZ
# Sin and Cos
temp.extend([np.sin(2 * np.pi * timeIndex * harmIndex * stimFrequence),
np.cos(2 * np.pi * timeIndex * harmIndex * stimFrequence)])
referenceTemp = np.mat(temp)
# 白化操作和QR分解
Q, rankQ = process(referenceTemp)
referenceData[frequenceIndex] = np.transpose(Q)
return referenceData
def generate_cca_references(
freqs: Union[ndarray, int, float],
srate,
T,
phases: Optional[Union[ndarray, int, float]] = None,
n_harmonics: int = 1,
):
"""
Construct a sine-cosine reference signal for canonical correlation analysis (CCA).
update log:
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
Parameters
----------
freqs : int or float
Frequency.
srate : int
Sampling rate.
T : int
Sampling time.
phases : int or float
Phase, default is None.
n_harmonics : int
The number of harmonics. The default value is 1.
Returns
-------
Yfndarray, shape(srate*T, n_harmonics*2)
Sine and cosine reference signal.
"""
if isinstance(freqs, int) or isinstance(freqs, float):
freqs = np.array([freqs])
freqs = np.array(freqs)[:, np.newaxis]
if phases is None:
phases = 0
if isinstance(phases, int) or isinstance(phases, float):
phases = np.array([phases])
phases = np.array(phases)[:, np.newaxis]
t = np.linspace(0, T, int(T * srate))
Yf = []
for i in range(n_harmonics):
Yf.append(
np.stack(
[
np.sin(2 * np.pi * (i + 1) * freqs * t + np.pi * phases),
np.cos(2 * np.pi * (i + 1) * freqs * t + np.pi * phases),
],
axis=1,
)
)
Yf = np.concatenate(Yf, axis=1)
return Yf
def sign_flip(u, s, vh=None):
"""Flip signs of SVD or EIG using the method in paper [1]_.
update log:
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
Parameters
----------
u: ndarray
left singular vectors, shape (M, K).
s: ndarray
singular values, shape (K,).
vh: ndarray or None
transpose of right singular vectors, shape (K, N).
Returns
-------
u: ndarray
corrected left singular vectors.
s: ndarray
singular values.
vh: ndarray
transpose of corrected right singular vectors.
References
----------
.. [1] https://www.sandia.gov/~tgkolda/pubs/pubfiles/SAND2007-6422.pdf
"""
if vh is None:
total_proj = np.sum(u * s, axis=0)
signs = np.sign(total_proj)
random_idx = signs == 0
if np.any(random_idx):
signs[random_idx] = 1
warnings.warn(
"The magnitude is close to zero, the sign will become arbitrary."
)
u = u * signs
return u, s
else:
left_proj = np.sum(s[:, np.newaxis] * vh, axis=-1)
right_proj = np.sum(u * s, axis=0)
total_proj = left_proj + right_proj
signs = np.sign(total_proj)
random_idx = signs == 0
if np.any(random_idx):
signs[random_idx] = 1
warnings.warn(
"The magnitude is close to zero, the sign will become arbitrary."
)
u = u * signs
vh = signs[:, np.newaxis] * vh
return u, s, vh

436
SSMVEP/algorithm/dsp.py Normal file
View File

@@ -0,0 +1,436 @@
# -*- coding: utf-8 -*-
# DSP: Discriminal Spatial Patterns
# Authors: Swolf <swolfforever@gmail.com>
# Junyang Wang <2144755928@qq.com>
# Last update date: 2022-8-11
# License: MIT License
from typing import Optional, List, Tuple
from itertools import combinations
import numpy as np
from scipy.linalg import eigh
from numpy import ndarray
from scipy.linalg import solve
from sklearn.base import BaseEstimator, TransformerMixin, ClassifierMixin
def robust_pattern(W : ndarray, Cx: ndarray, Cs: ndarray) -> ndarray:
"""Transform spatial filters to spatial patterns based on paper [1]_.
Referring to the method mentioned in article [1],the constructed spatial filter only shows how to combine
information from different channels to extract signals of interest from EEG signals, but if our goal is
neurophysiological interpretation or visualization of weights, activation patterns need to be constructed
from the obtained spatial filters.
update log:
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
Parameters
----------
W : ndarray
Spatial filters, shape (n_channels, n_filters).
Cx : ndarray
Covariance matrix of eeg data, shape (n_channels, n_channels).
Cs : ndarray
Covariance matrix of source data, shape (n_channels, n_channels).
Returns
-------
A : ndarray
Spatial patterns, shape (n_channels, n_patterns), each column is a spatial pattern.
References
----------
.. [1] Haufe, Stefan, et al. "On the interpretation of weight vectors of linear models in multivariate neuroimaging.
Neuroimage 87 (2014): 96-110.
"""
# use linalg.solve instead of inv, makes it more stable
# see https://github.com/robintibor/fbcsp/blob/master/fbcsp/signalproc.py
# and https://ww2.mathworks.cn/help/matlab/ref/mldivide.html
A = solve(Cs.T, np.dot(Cx, W).T).T
return A
def isPD(B: ndarray) -> bool:
"""Returns true when input matrix is positive-definite, via Cholesky decompositon method.
Parameters
----------
B : ndarray
Any matrix, shape (N, N)
Returns
-------
bool
True if B is positve-definite.
Notes
-----
Use numpy.linalg rather than scipy.linalg. In this case, scipy.linalg has unpredictable behaviors.
"""
try:
_ = np.linalg.cholesky(B)
return True
except np.linalg.LinAlgError:
return False
def nearestPD(A: ndarray) -> ndarray:
"""Find the nearest positive-definite matrix to input.
Parameters
----------
A : ndarray
Any square matrxi, shape (N, N)
Returns
-------
A3 : ndarray
positive-definite matrix to A
Notes
-----
A Python/Numpy port of John D'Errico's `nearestSPD` MATLAB code [1]_, which
origins at [2]_.
References
----------
.. [1] https://www.mathworks.com/matlabcentral/fileexchange/42885-nearestspd
.. [2] N.J. Higham, "Computing a nearest symmetric positive semidefinite matrix" (1988):
https://doi.org/10.1016/0024-3795(88)90223-6
"""
B = (A + A.T) / 2
_, s, V = np.linalg.svd(B)
H = np.dot(V.T, np.dot(np.diag(s), V))
A2 = (B + H) / 2
A3 = (A2 + A2.T) / 2
if isPD(A3):
return A3
print("Replace current matrix with the nearest positive-definite matrix.")
spacing = np.spacing(np.linalg.norm(A))
# The above is different from [1]. It appears that MATLAB's `chol` Cholesky
# decomposition will accept matrixes with exactly 0-eigenvalue, whereas
# Numpy's will not. So where [1] uses `eps(mineig)` (where `eps` is Matlab
# for `numpy.spacing`), we use the above definition. CAVEAT: our `spacing`
# will be much larger than [1]'s `eps(mineig)`, since `mineig` is usually on
# the order of 1e-16, and `eps(1e-16)` is on the order of 1e-34, whereas
# `spacing` will, for Gaussian random matrixes of small dimension, be on
# othe order of 1e-16. In practice, both ways converge, as the unit test
# below suggests.
eye = np.eye(A.shape[0])
k = 1
while not isPD(A3):
mineig = np.min(np.real(np.linalg.eigvals(A3)))
A3 += eye * (-mineig * k**2 + spacing)
k += 1
return A3
def xiang_dsp_kernel(
X: ndarray, y: ndarray
) -> Tuple[ndarray, ndarray, ndarray, ndarray]:
"""
DSP: Discriminal Spatial Patterns, only for two classes[1]_.
Import train data to solve spatial filters with DSP,
finds a projection matrix that maximize the between-class scatter matrix and
minimize the within-class scatter matrix. Currently only support for two types of data.
Author: Swolf <swolfforever@gmail.com>
Created on: 2021-1-07
Update log:
Parameters
----------
X : ndarray
EEG train data assuming removing mean, shape (n_trials, n_channels, n_samples)
y : ndarray
labels of EEG data, shape (n_trials, )
Returns
-------
W : ndarray
spatial filters, shape (n_channels, n_filters)
D : ndarray
eigenvalues in descending order
M : ndarray
mean value of all classes and trials, i.e. common mode signals, shape (n_channel, n_samples)
A : ndarray
spatial patterns, shape (n_channels, n_filters)
Notes
-----
the implementation removes regularization on within-class scatter matrix Sw.
References
----------
.. [1] Liao, Xiang, et al. "Combining spatial filters for the classification of single-trial EEG in
a finger movement task." IEEE Transactions on Biomedical Engineering 54.5 (2007): 821-831.
"""
X, y = np.copy(X), np.copy(y)
labels = np.unique(y)
X = np.reshape(X, (-1, *X.shape[-2:]))
X = X - np.mean(X, axis=-1, keepdims=True)
# the number of each label
n_labels = np.array([np.sum(y == label) for label in labels])
# average template of all trials
M = np.mean(X, axis=0)
# class conditional template
Ms, Ss = zip(
*[
(
np.mean(X[y == label], axis=0),
np.sum(
np.matmul(X[y == label], np.swapaxes(X[y == label], -1, -2)), axis=0
),
)
for label in labels
]
)
Ms, Ss = np.stack(Ms), np.stack(Ss)
# within-class scatter matrix
Sw = np.sum(
Ss
- n_labels[:, np.newaxis, np.newaxis] * np.matmul(Ms, np.swapaxes(Ms, -1, -2)),
axis=0,
)
Ms = Ms - M
# between-class scatter matrix
Sb = np.sum(
n_labels[:, np.newaxis, np.newaxis] * np.matmul(Ms, np.swapaxes(Ms, -1, -2)),
axis=0,
)
D, W = eigh(nearestPD(Sb), nearestPD(Sw))
ix = np.argsort(D)[::-1] # in descending order
D, W = D[ix], W[:, ix]
A = robust_pattern(W, Sb, W.T @ Sb @ W)
return W, D, M, A
def xiang_dsp_feature(
W: ndarray, M: ndarray, X: ndarray, n_components: int = 1
) -> ndarray:
"""
Return DSP features in paper [1]_.
Author: Swolf <swolfforever@gmail.com>
Created on: 2021-1-07
Update log:
Parameters
----------
W : ndarray
spatial filters from csp_kernel, shape (n_channels, n_filters)
M : ndarray
common template for all classes, shape (n_channel, n_samples)
X : ndarray
eeg test data, shape (n_trials, n_channels, n_samples)
n_components : int, optional
length of the spatial filters, first k components to use, by default 1
Returns
-------
features: ndarray
features, shape (n_trials, n_components, n_samples)
Raises
------
ValueError
n_components should less than half of the number of channels
Notes
-----
1. instead of meaning of filtered signals in paper [1]_., we directly return filtered signals.
References
----------
.. [1] Liao, Xiang, et al. "Combining spatial filters for the classification of single-trial EEG in
a finger movement task." IEEE Transactions on Biomedical Engineering 54.5 (2007): 821-831.
"""
W, M, X = np.copy(W), np.copy(M), np.copy(X)
max_components = W.shape[1]
if n_components > max_components:
raise ValueError("n_components should less than the number of channels")
X = np.reshape(X, (-1, *X.shape[-2:]))
X = X - np.mean(X, axis=-1, keepdims=True)
# print('************: ',np.shape(W),np.shape(X),np.shape(M))
features = np.matmul(W[:, :n_components].T, X - M)
return features
class DSP(BaseEstimator, TransformerMixin, ClassifierMixin):
"""
DSP: Discriminal Spatial Patterns
Author: Swolf <swolfforever@gmail.com>
Created on: 2021-1-07
Update log:
Parameters
----------
n_components : int
length of the spatial filter, first k components to use, by default 1
transform_method : str
method of template matching, by default corr (pearson correlation coefficient)
classes_ : int
number of the EEG classes
Attributes
----------
n_components : int
length of the spatial filter, first k components to use, by default 1
transform_method : str
method of template matching, by default corr (pearson correlation coefficient)
classes_ : int
number of the EEG classes
W_ : ndarray, shape(n_channels, n_filters)
Spatial filters, shape(n_channels, n_filters), in which n_channels = n_filters
D_ : ndarray, shape(n_filters )
eigenvalues in descending order, shape(n_filters, )
M_ : ndarray, shape(n_channels, n_samples)
mean value of all classes and trials, i.e. common mode signals, shape(n_channels, n_samples)
A_ : ndarray, shape(n_channels, n_filters)
spatial patterns, shape(n_channels, n_filters)
templates_: ndarray, shape(n_classes, n_filters, n_samples)
templates of train data, shape(n_classes, n_filters, n_samples)
"""
def __init__(self, n_components: int = 1, transform_method: str = "corr"):
self.n_components = n_components
self.transform_method = transform_method
def fit(self, X: ndarray, y: ndarray, Yf: Optional[ndarray] = None):
"""
Import the train data to get a model.
Parameters
----------
X : ndarray
train data, shape(n_trials, n_channels, n_samples)
y : ndarray
labels of train data, shape (n_trials, )
Yf : ndarray
optional parameter
Returns
-------
W_ : ndarray
spatial filters, shape (n_channels, n_filters), in which n_channels = n_filters
D_ : ndarray
eigenvalues in descending order, shape (n_filters, )
M_ : ndarray
template for all classes, shape (n_channel, n_samples)
A_ : ndarray
spatial patterns, shape (n_channels, n_filters)
templates_ : ndarray
templates of train data, shape (n_channels, n_filters, n_samples)
"""
X -= np.mean(X, axis=-1, keepdims=True)
self.classes_ = np.unique(y)
self.W_, self.D_, self.M_, self.A_ = xiang_dsp_kernel(X, y)
self.templates_ = np.stack(
[
np.mean(
xiang_dsp_feature(
self.W_, self.M_, X[y == label], n_components=self.W_.shape[1]
),
axis=0,
)
for label in self.classes_
]
)
return self
def transform(self, X: ndarray):
"""
Import the test data to get features.
Parameters
----------
X : ndarray
test data, shape(n_trials, n_channels, n_samples)
Returns
-------
feature : ndarray, shape(n_trials,n_classes)
correlation coefficients of templates of train data and features of test data, shape(n_trials, n_classes)
"""
n_components = self.n_components
X -= np.mean(X, axis=-1, keepdims=True)
features = xiang_dsp_feature(self.W_, self.M_, X, n_components=n_components)
if self.transform_method is None:
return features.reshape((features.shape[0], -1))
elif self.transform_method == "mean":
return np.mean(features, axis=-1)
elif self.transform_method == "corr":
return self._pearson_features(
features, self.templates_[:, :n_components, :]
)
else:
raise ValueError("non-supported transform method")
def _pearson_features(self, X: ndarray, templates: ndarray):
"""
Calculate pearson correlation coefficient.
Parameters
----------
X : ndarray
features of test data after spatial filters, shape(n_trials, n_components, n_samples)
templates : ndarray
templates of train data, shape(n_classes, n_components, n_samples)
Returns
-------
corr : ndarray
pearson correlation coefficient, shape(n_trials, n_classes)
"""
X = np.reshape(X, (-1, *X.shape[-2:]))
templates = np.reshape(templates, (-1, *templates.shape[-2:]))
X = X - np.mean(X, axis=-1, keepdims=True)
templates = templates - np.mean(templates, axis=-1, keepdims=True)
X = np.reshape(X, (X.shape[0], -1))
templates = np.reshape(templates, (templates.shape[0], -1))
istd_X = 1 / np.std(X, axis=-1, keepdims=True)
istd_templates = 1 / np.std(templates, axis=-1, keepdims=True)
corr = (X @ templates.T) / (templates.shape[1] - 1)
corr = istd_X * corr * istd_templates.T
return corr
def predict(self, X: ndarray):
"""
Import the templates and the test data to get prediction labels.
Parameters
----------
X : ndarray
test data, shape(n_trials, n_channels, n_samples)
Returns
-------
labels : ndarray
prediction labels of test data, shape(n_trials,)
"""
feat = self.transform(X)
if self.transform_method == "corr":
labels = self.classes_[np.argmax(feat, axis=-1)]
else:
raise NotImplementedError()
return labels

175
SSMVEP/algorithm/tdca.py Normal file
View File

@@ -0,0 +1,175 @@
# -*- coding: utf-8 -*-
#
# Authors: Swolf <swolfforever@gmail.com>
# Date: 2021/10/10
# License: MIT License
"""
Task Decomposition Component Analysis.
"""
from typing import List
import numpy as np
from scipy.linalg import qr
from scipy.stats import pearsonr
from numpy import ndarray
from sklearn.base import BaseEstimator, TransformerMixin, ClassifierMixin
from typing import Optional, List
from SSMVEP.algorithm.base import FilterBankSSVEP
from SSMVEP.algorithm.dsp import xiang_dsp_kernel, xiang_dsp_feature
def proj_ref(Yf: ndarray):
Q, R = qr(Yf.T, mode="economic")
P = Q @ Q.T
return P
def aug_2(X: ndarray, n_samples: int, padding_len: int, P: ndarray, training: bool = True):
X = X.reshape((-1, *X.shape[-2:]))
n_trials, n_channels, n_points = X.shape
# if n_points < padding_len + n_samples:
# raise ValueError("the length of X should be larger than l+n_samples.")
aug_X = np.zeros((n_trials, (padding_len + 1) * n_channels, n_samples))
if training:
for i in range(padding_len + 1):
aug_X[:, i * n_channels : (i + 1) * n_channels, :] = X[
..., i : i + n_samples
]
else:
for i in range(padding_len + 1):
aug_X[:, i * n_channels : (i + 1) * n_channels, : n_samples - i] = X[
..., i:n_samples
]
aug_Xp = aug_X @ P
aug_X = np.concatenate([aug_X, aug_Xp], axis=-1)
return aug_X
def tdca_feature(
X: ndarray,
templates: ndarray,
W: ndarray,
M: ndarray,
Ps: List[ndarray],
padding_len: int,
n_components: int = 1,
training=False,
):
rhos = []
for Xk, P in zip(templates, Ps):
a = xiang_dsp_feature(
W,
M,
aug_2(X, P.shape[0], padding_len, P, training=training),
n_components=n_components,
)
b = Xk[:n_components, :]
a = np.reshape(a, (-1))
b = np.reshape(b, (-1))
rhos.append(pearsonr(a, b)[0])
return rhos
class TDCA(BaseEstimator, TransformerMixin, ClassifierMixin):
def __init__(self, padding_len: int, n_components: int = 1):
self.padding_len = padding_len
self.n_components = n_components
def fit(self, X: ndarray, y: ndarray, Yf: ndarray):
X -= np.mean(X, axis=-1, keepdims=True)
self.classes_ = np.unique(y)
self.Ps_ = [proj_ref(Yf[i]) for i in range(len(self.classes_))]
# print(np.shape(self.Ps_))
aug_X_list, aug_Y_list = [], []
for i, label in enumerate(self.classes_):
aug_X_list.append(
aug_2(
X[y == label],
self.Ps_[i].shape[0],
self.padding_len,
self.Ps_[i],
training=True,
)
)
aug_Y_list.append(y[y == label])
aug_X = np.concatenate(aug_X_list, axis=0)
aug_Y = np.concatenate(aug_Y_list, axis=0)
self.W_, _, self.M_, _ = xiang_dsp_kernel(aug_X, aug_Y)
self.templates_ = np.stack(
[
np.mean(
xiang_dsp_feature(
self.W_,
self.M_,
aug_X[aug_Y == label],
n_components=self.W_.shape[1],
),
axis=0,
)
for label in self.classes_
]
)
return self
def transform(self, X: ndarray):
n_components = self.n_components
X -= np.mean(X, axis=-1, keepdims=True)
X = X.reshape((-1, *X.shape[-2:]))
rhos = [
tdca_feature(
tmp,
self.templates_,
self.W_,
self.M_,
self.Ps_,
self.padding_len,
n_components=n_components,
)
for tmp in X
]
rhos = np.stack(rhos)
return rhos
def predict(self, X: ndarray):
feat = self.transform(X)
labels = self.classes_[np.argmax(feat, axis=-1)]
return labels,feat
class FBTDCA(FilterBankSSVEP, ClassifierMixin):
def __init__(
self,
filterbank: List[ndarray],
padding_len: int,
n_components: int = 1,
filterweights: Optional[ndarray] = None,
n_jobs: Optional[int] = None,
):
self.padding_len = padding_len
self.n_components = n_components
self.filterweights = filterweights
self.n_jobs = n_jobs
super().__init__(
filterbank,
TDCA(padding_len, n_components=n_components),
filterweights=filterweights,
n_jobs=n_jobs,
)
def fit(self, X: ndarray, y: ndarray, Yf: Optional[ndarray] = None): # type: ignore[override]
self.classes_ = np.unique(y)
super().fit(X, y, Yf=Yf)
return self
def predict(self, X: ndarray):
features = self.transform(X)
if self.filterweights is None:
features = np.reshape(
features, (features.shape[0], len(self.filterbank), -1)
)
features = np.mean(features, axis=1)
labels = self.classes_[np.argmax(features, axis=-1)]
return labels,features