2026-06-05 09:34:29 +08:00
"""
EEG Conformer
Convolutional Transformer for EEG decoding
Couple CNN and Transformer in a concise manner with amazing results
"""
# remember to change paths
import os
gpus = [ 0 ]
os . environ [ ' CUDA_DEVICE_ORDER ' ] = ' PCI_BUS_ID '
os . environ [ " CUDA_VISIBLE_DEVICES " ] = ' , ' . join ( map ( str , gpus ) )
import numpy as np
import math
import random
import time
import datetime
from torch . utils . data import DataLoader
from torch . autograd import Variable
import torch
import torch . nn . functional as F
from torch import nn
from torch import Tensor
from einops import rearrange
from einops . layers . torch import Rearrange , Reduce
# from common_spatial_pattern import csp
# from torch.utils.tensorboard import SummaryWriter
from torch . backends import cudnn
cudnn . benchmark = True
cudnn . deterministic = True
from sklearn . model_selection import train_test_split
# writer = SummaryWriter('./TensorBoardX/')
2026-06-10 16:04:02 +08:00
from logs . log import algo_log
2026-06-05 09:34:29 +08:00
# Convolution module
# use conv to capture local features, instead of position embedding.
class PatchEmbedding(nn.Module):
    """Convolutional front-end that turns an EEG window into a token sequence.

    The module chains a temporal convolution (kernel 1x25), a convolution
    whose kernel spans ``n_chan`` rows, batch normalisation, ELU, average
    pooling (window 75, stride 15 along the last axis) and dropout.  A 1x1
    convolution then maps the 40 feature maps to ``emb_size`` and the result
    is rearranged to ``(batch, h * w, emb_size)``.

    Args:
        emb_size: Number of output features per token.
        n_chan: Height of the second convolution kernel (presumably the
            number of EEG electrodes -- confirm against the caller).
    """

    def __init__(self, emb_size=40, n_chan=8):
        super().__init__()

        temporal_conv = nn.Conv2d(1, 40, (1, 25), (1, 1))
        spatial_conv = nn.Conv2d(40, 40, (n_chan, 1), (1, 1))
        # Pooling acts as slicing to obtain 'patches' along the time
        # dimension, as in ViT.
        patch_pool = nn.AvgPool2d((1, 75), (1, 15))
        self.shallownet = nn.Sequential(
            temporal_conv,
            spatial_conv,
            nn.BatchNorm2d(40),
            nn.ELU(),
            patch_pool,
            nn.Dropout(0.5),
        )

        # The 1x1 convolution could slightly enhance fitting ability; the
        # rearrange moves the feature axis last so each position is a token.
        to_embedding = nn.Conv2d(40, emb_size, (1, 1), stride=(1, 1))
        self.projection = nn.Sequential(
            to_embedding,
            Rearrange('b e (h) (w) -> b (h w) e'),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Return the token sequence for the 4-D input ``x``."""
        # Unpacking four dimensions fails fast when ``x`` is not 4-D.
        _, _, _, _ = x.shape
        features = self.shallownet(x)
        tokens = self.projection(features)
        return tokens
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over a ``(batch, tokens, emb_size)`` input.

    Args:
        emb_size: Feature size of every token; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``emb_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        if emb_size % num_heads != 0:
            raise ValueError(
                f"emb_size ({emb_size}) must be divisible by num_heads ({num_heads})"
            )
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def _split_heads(self, t: Tensor) -> Tensor:
        """Reshape ``(b, n, h * d)`` into ``(b, h, n, d)``."""
        batch, tokens, _ = t.shape
        return t.reshape(batch, tokens, self.num_heads, -1).permute(0, 2, 1, 3)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, tokens, emb_size)``.
            mask: Optional boolean tensor broadcastable to
                ``(batch, heads, query, key)``.  Positions that are ``False``
                are excluded from attention.

        Returns:
            Tensor with the same shape as ``x``.
        """
        queries = self._split_heads(self.queries(x))
        keys = self._split_heads(self.keys(x))
        values = self._split_heads(self.values(x))

        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        if mask is not None:
            # Use the most negative value of the scores' own dtype (a float64
            # minimum cannot be stored in float32 scores), and keep the
            # result: ``masked_fill`` is not an in-place operation.
            fill_value = torch.finfo(energy.dtype).min
            energy = energy.masked_fill(~mask, fill_value)

        # NOTE: scores are scaled by sqrt(emb_size), not sqrt(head_dim); this
        # is kept unchanged so existing trained weights behave the same.
        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)

        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        batch, heads, tokens, head_dim = out.shape
        # Merge the heads back: (b, h, n, d) -> (b, n, h * d).
        out = out.permute(0, 2, 1, 3).reshape(batch, tokens, heads * head_dim)
        return self.projection(out)
class ResidualAdd(nn.Module):
    """Wrap a module with a skip connection: ``fn(x) + x``.

    Args:
        fn: Callable (typically an ``nn.Module``) whose output has the same
            shape as its input.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``self.fn(x, **kwargs) + x``.

        The sum is computed out of place.  An in-place ``+=`` on the output
        of ``fn`` would silently overwrite the caller's tensor whenever
        ``fn`` returns its input unchanged.
        """
        return self.fn(x, **kwargs) + x
class FeedForwardBlock(nn.Sequential):
    """Position-wise feed-forward network: Linear -> GELU -> Dropout -> Linear.

    Args:
        emb_size: Input and output feature size.
        expansion: Multiplier giving the hidden layer width
            (``expansion * emb_size``).
        drop_p: Dropout probability applied after the activation.
    """

    def __init__(self, emb_size, expansion, drop_p):
        hidden_size = expansion * emb_size
        layers = [
            nn.Linear(emb_size, hidden_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden_size, emb_size),
        ]
        super().__init__(*layers)
class GELU(nn.Module):
    """Exact (erf-based) Gaussian Error Linear Unit."""

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input * 0.5 * (1 + erf(input / sqrt(2)))`` element-wise."""
        half_input = input * 0.5
        erf_term = torch.erf(input / math.sqrt(2.0))
        return half_input * (1.0 + erf_term)
class TransformerEncoderBlock(nn.Sequential):
    """One pre-norm Transformer encoder layer.

    The layer is two residual branches applied in sequence:
    LayerNorm -> multi-head attention -> dropout, followed by
    LayerNorm -> feed-forward block -> dropout.

    Args:
        emb_size: Token feature size.
        num_heads: Number of attention heads.
        drop_p: Dropout probability for the attention weights and for the
            dropout closing each branch.
        forward_expansion: Hidden-width multiplier of the feed-forward block.
        forward_drop_p: Dropout probability inside the feed-forward block.
    """

    def __init__(self,
                 emb_size,
                 num_heads=10,
                 drop_p=0.5,
                 forward_expansion=4,
                 forward_drop_p=0.5):
        # Sub-modules are created in the same order as they are applied so
        # that parameter initialisation consumes the RNG identically.
        attention_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            MultiHeadAttention(emb_size, num_heads, drop_p),
            nn.Dropout(drop_p),
        )
        feed_forward_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            FeedForwardBlock(
                emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
            nn.Dropout(drop_p),
        )
        super().__init__(
            ResidualAdd(attention_branch),
            ResidualAdd(feed_forward_branch),
        )
class TransformerEncoder(nn.Sequential):
    """Stack of ``depth`` encoder blocks, each using its default settings.

    Args:
        depth: Number of ``TransformerEncoderBlock`` layers.
        emb_size: Token feature size passed to every block.
    """

    def __init__(self, depth, emb_size):
        blocks = []
        for _ in range(depth):
            blocks.append(TransformerEncoderBlock(emb_size))
        super().__init__(*blocks)
class ClassificationHead(nn.Sequential):
    """Fully connected classifier applied to the flattened token sequence.

    Args:
        emb_size: Token feature size (used by ``clshead``).
        n_classes: Number of output logits.
        flat_features: Size of the flattened encoder output, i.e.
            ``tokens * emb_size``.  The default 2440 corresponds to 61
            tokens of size 40, which is what ``PatchEmbedding`` yields for a
            1000-sample window (``(1000 - 25 + 1 - 75) // 15 + 1 == 61``).

    Note:
        ``clshead`` (global average pooling head) is not used by
        ``forward``.  It is kept so the set of parameters and the seeded
        initialisation order stay the same as before.
    """

    def __init__(self, emb_size, n_classes, flat_features=2440):
        super().__init__()

        # Global average pooling head (currently unused by ``forward``).
        self.clshead = nn.Sequential(
            Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, n_classes)
        )
        self.fc = nn.Sequential(
            nn.Linear(flat_features, 256),
            nn.ELU(),
            nn.Dropout(0.5),
            nn.Linear(256, 32),
            nn.ELU(),
            nn.Dropout(0.3),
            # Honour ``n_classes`` instead of a hard-coded 2 outputs.
            nn.Linear(32, n_classes)
        )

    def forward(self, x):
        """Flatten ``x`` per sample and return the class logits."""
        flat = x.contiguous().view(x.size(0), -1)
        return self.fc(flat)
class Conformer(nn.Sequential):
    """EEG Conformer: patch embedding -> Transformer encoder -> classifier.

    Args:
        emb_size: Token feature size shared by all three stages.
        depth: Number of Transformer encoder blocks.
        n_classes: Number of classes passed to the classification head.
        n_chan: Channel count passed to the patch embedding.
        **kwargs: Accepted for call compatibility and ignored.
    """

    def __init__(self, emb_size=40, depth=6, n_classes=2, n_chan=8, **kwargs):
        # Stages are built in forward order so parameter initialisation
        # consumes the RNG in the same sequence as before.
        embedding = PatchEmbedding(emb_size, n_chan)
        encoder = TransformerEncoder(depth, emb_size)
        head = ClassificationHead(emb_size, n_classes)
        super().__init__(embedding, encoder, head)
class ExP():
    """Training experiment for the two-class Conformer model on CUDA.

    The constructor builds the model, loss and log file; ``train`` splits the
    data, augments the training part once with Segmentation and
    Reconstruction (S&R), trains for ``n_epochs`` and saves the model every
    time the test accuracy improves.

    Labels are expected to be 1-based (``1`` and ``2``): ``interaug`` selects
    trials with ``label == cls + 1`` and ``train`` subtracts 1 before
    building the tensors.

    Args:
        n_chan: Number of EEG channels (height of the model input).
    """

    # S&R geometry: every augmented trial is stitched from N_SEGMENTS
    # consecutive pieces of SEGMENT_LEN samples (8 * 125 == 1000 samples).
    N_SEGMENTS = 8
    SEGMENT_LEN = 125
    # Per-epoch accuracy log written by ``train``.
    LOG_PATH = "./online_Models/log_result.txt"

    def __init__(self, n_chan):
        super().__init__()
        self.n_chan = n_chan
        self.batch_size = 24
        self.n_epochs = 250
        self.c_dim = 4  # not read anywhere in this class
        self.lr = 0.0002
        self.b1 = 0.5
        self.b2 = 0.999
        self.start_epoch = 0  # not read anywhere in this class

        self.Tensor = torch.cuda.FloatTensor
        self.LongTensor = torch.cuda.LongTensor

        self.criterion_cls = torch.nn.CrossEntropyLoss().cuda()

        self.model = Conformer(n_chan=self.n_chan).cuda()
        self.model = nn.DataParallel(self.model, device_ids=list(range(len(gpus))))
        self.model = self.model.cuda()

        # Create the output directory and open the log last, so a failure
        # while building the model does not leave an open file behind.
        os.makedirs("online_Models", exist_ok=True)
        self.log_write = open(self.LOG_PATH, "w")

    # Segmentation and Reconstruction (S&R) data augmentation
    def interaug(self, timg, label):
        """Create ``batch_size`` synthetic trials, half for each class.

        Each synthetic trial of a class is assembled from ``N_SEGMENTS``
        time segments, each copied from a randomly chosen real trial of the
        same class.

        Args:
            timg: Trials of shape ``(n_trials, 1, n_chan, 1000)`` as a NumPy
                array or a torch tensor.
            label: 1-based class labels of shape ``(n_trials,)``.

        Returns:
            Tuple ``(aug_data, aug_label)`` of NumPy arrays with shapes
            ``(batch_size, 1, n_chan, 1000)`` and ``(batch_size,)``, shuffled
            with the same permutation.

        Raises:
            ValueError: If one of the two classes has no trial.
        """
        # Work on CPU NumPy arrays; the caller decides about moving to GPU.
        if isinstance(timg, torch.Tensor):
            timg = timg.cpu().numpy()
        if isinstance(label, torch.Tensor):
            label = label.cpu().numpy()
        timg = np.asarray(timg)
        label = np.asarray(label)

        n_aug = int(self.batch_size / 2)
        seg_len = self.SEGMENT_LEN
        n_samples = self.N_SEGMENTS * seg_len

        aug_data = []
        aug_label = []
        for cls4aug in range(2):
            class_value = cls4aug + 1
            cls_idx = np.where(label == class_value)
            tmp_data = timg[cls_idx]
            n_real = tmp_data.shape[0]
            if n_real == 0:
                raise ValueError(
                    f"interaug needs at least one trial with label {class_value}"
                )

            tmp_aug_data = np.zeros((n_aug, 1, self.n_chan, n_samples))
            for ri in range(n_aug):
                for rj in range(self.N_SEGMENTS):
                    # A fresh vector of N_SEGMENTS indices is drawn for every
                    # segment although only one entry is used; the draw
                    # pattern is kept so seeded runs stay reproducible.
                    rand_idx = np.random.randint(0, n_real, self.N_SEGMENTS)
                    start, stop = rj * seg_len, (rj + 1) * seg_len
                    tmp_aug_data[ri, :, :, start:stop] = tmp_data[rand_idx[rj], :, :, start:stop]

            aug_data.append(tmp_aug_data)
            # Build the labels explicitly.  Slicing the real labels gave
            # fewer labels than synthetic trials whenever the class had
            # fewer than ``n_aug`` real trials.
            aug_label.append(np.full(n_aug, class_value, dtype=label.dtype))

        aug_data = np.concatenate(aug_data)
        aug_label = np.concatenate(aug_label)
        aug_shuffle = np.random.permutation(len(aug_data))
        aug_data = aug_data[aug_shuffle]
        aug_label = aug_label[aug_shuffle]

        return aug_data, aug_label

    def train(self, all_data, all_label, model_path):
        """Train the model and save it whenever the test accuracy improves.

        Args:
            all_data: Trials of shape ``(n_trials, n_chan, n_samples)``; a
                singleton axis is inserted at position 1.
            all_label: 1-based class labels, one per trial.
            model_path: Destination passed to ``torch.save`` for the whole
                (``DataParallel``-wrapped) model.

        Returns:
            Tuple ``(bestAcc, averAcc, Y_true, Y_pred)``: best and mean test
            accuracy over the epochs, and the test labels / predictions of
            the best epoch (both ``0`` if the accuracy never exceeded 0).
        """
        all_data = np.array(all_data)
        all_label = np.array(all_label)
        all_data = np.expand_dims(all_data, axis=1)
        train_data, test_data, train_label, test_label = train_test_split(
            all_data, all_label, test_size=0.2,
            random_state=42, stratify=all_label, shuffle=True)

        # Generate the augmented data once up front instead of recomputing
        # it for every batch.
        aug_data, aug_label = self.interaug(train_data, train_label)

        # Merge real and augmented trials, then shuffle them together.
        train_data_full = np.concatenate([train_data, aug_data], axis=0)
        train_label_full = np.concatenate([train_label, aug_label], axis=0)
        shuffle_idx = np.random.permutation(len(train_data_full))
        train_data_full = train_data_full[shuffle_idx]
        train_label_full = train_label_full[shuffle_idx]

        # Labels become 0-based for CrossEntropyLoss.
        train_x = torch.from_numpy(train_data_full)
        train_y = torch.from_numpy(train_label_full - 1)
        dataset = torch.utils.data.TensorDataset(train_x, train_y)
        self.dataloader = torch.utils.data.DataLoader(
            dataset=dataset, batch_size=self.batch_size, shuffle=True)

        test_data = torch.from_numpy(test_data)
        test_label = torch.from_numpy(test_label - 1)
        test_dataset = torch.utils.data.TensorDataset(test_data, test_label)
        # Not iterated below (the test set is evaluated in one pass); kept
        # as an attribute for code that may read it.
        self.test_dataloader = torch.utils.data.DataLoader(
            dataset=test_dataset, batch_size=self.batch_size, shuffle=True)

        # Optimizers
        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2))

        test_data = test_data.type(self.Tensor)
        test_label = test_label.type(self.LongTensor)

        bestAcc = 0
        averAcc = 0
        num = 0
        Y_true = 0
        Y_pred = 0

        # Allow ``train`` to be called again after an earlier call closed
        # the log.
        if self.log_write.closed:
            self.log_write = open(self.LOG_PATH, "a")

        try:
            for e in range(self.n_epochs):
                self.model.train()
                for batch_x, batch_y in self.dataloader:
                    batch_x = batch_x.cuda().type(self.Tensor)
                    batch_y = batch_y.cuda().type(self.LongTensor)

                    outputs = self.model(batch_x)
                    loss = self.criterion_cls(outputs, batch_y)

                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                # Evaluate on the held-out split after every epoch; no
                # gradients are needed here.
                self.model.eval()
                with torch.no_grad():
                    Cls = self.model(test_data)
                    loss_test = self.criterion_cls(Cls, test_label)
                y_pred = torch.max(Cls, 1)[1]
                acc = float((y_pred == test_label).sum().item()) / float(test_label.size(0))

                # Training accuracy is measured on the last batch only.
                train_pred = torch.max(outputs, 1)[1]
                train_acc = float((train_pred == batch_y).sum().item()) / float(batch_y.size(0))

                algo_log(
                    f"Epoch = {e}, Train loss = {loss.item():.6f}, "
                    f"Test loss = {loss_test.item():.6f}, "
                    f"Train accuracy = {train_acc:.6f}, Test accuracy = {acc:.6f}",
                    level="debug")

                self.log_write.write(str(e) + " " + str(acc) + "\n")
                num = num + 1
                averAcc = averAcc + acc
                if acc > bestAcc:
                    bestAcc = acc
                    Y_true = test_label
                    Y_pred = y_pred
                    torch.save(self.model, model_path)

            # Guard against n_epochs == 0, where no accuracy was recorded.
            averAcc = averAcc / num if num else 0.0
            algo_log(f"The average accuracy is: {averAcc}", level="debug")
            algo_log(f"The best accuracy is: {bestAcc}", level="debug")
            self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n")
            self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n")
        finally:
            # Always release the log file, even when training fails.
            self.log_write.close()

        return bestAcc, averAcc, Y_true, Y_pred
# writer.close()
def onlineTrain(data_queue, result_queue):
    """Train a model from one job taken off ``data_queue``.

    The job is a dict with the keys ``'data'``, ``'label'``, ``'modelPath'``
    and ``'n_chan'``.  The outcome is reported on ``result_queue``:

    * success: ``{'status': 'success', 'model_state': <model path>,
      'timestamp': <time.time()>}``
    * failure: ``{'status': 'error', 'msg': <message>,
      'traceback': <formatted traceback>}``

    Args:
        data_queue: Queue-like object providing ``get(timeout=...)``.
        result_queue: Queue-like object providing ``put(...)``.
    """
    import traceback

    import torch

    algo_log(f"[DEBUG] torch.__version__ = {torch.__version__}", level="debug")
    algo_log(f"[DEBUG] torch.cuda.is_available() = {torch.cuda.is_available()}", level="debug")
    if torch.cuda.is_available():
        algo_log(f"[DEBUG] GPU = {torch.cuda.get_device_name(0)}", level="debug")

    try:
        starttime = datetime.datetime.now()

        # Fixed seed for reproducible runs.
        # seed_n = np.random.randint(2025)
        seed_n = 1877
        random.seed(seed_n)
        np.random.seed(seed_n)
        torch.manual_seed(seed_n)
        torch.cuda.manual_seed(seed_n)
        torch.cuda.manual_seed_all(seed_n)

        # Fetch the training job from the queue (waits at most 30 seconds).
        data = data_queue.get(timeout=30)
        all_data, all_label, model_path, n_chan = (
            data['data'], data['label'], data['modelPath'], data['n_chan'])

        exp = ExP(n_chan)
        algo_log(f"训练参数: {np.shape(all_data)}, {np.shape(all_label)}, {model_path}", level="debug")
        bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data, all_label, model_path)
        algo_log(f"THE BEST ACCURACY IS {str(bestAcc)}", level="debug")

        endtime = datetime.datetime.now()
        algo_log(f"train duration: {endtime - starttime}", level="debug")

        # Report the location of the saved model back to the caller.
        result_queue.put({
            'status': 'success',
            'model_state': model_path,  # path of the saved model
            'timestamp': time.time()
        })

    except Exception as e:
        error_trace = traceback.format_exc()
        # Some exceptions (e.g. queue.Empty on timeout) have an empty
        # str(); fall back to repr() so the consumer still learns the cause.
        result_queue.put({
            'status': 'error',
            'msg': str(e) or repr(e),
            'traceback': error_trace,
        })
        # Log after reporting, so a logging problem cannot hide the result.
        algo_log(f"onlineTrain failed: {error_trace}", level="debug")
def offlineTrain(all_data, all_label, modelPath, n_chan=None):
    """Train a model in the current process and save it to ``modelPath``.

    Args:
        all_data: Trials of shape ``(n_trials, n_chan, n_samples)``.
        all_label: 1-based class labels, one per trial.
        modelPath: Destination of the saved model.
        n_chan: Number of EEG channels.  When omitted it is taken from
            axis 1 of ``all_data``.

    Raises:
        ValueError: If ``n_chan`` is omitted and ``all_data`` has fewer than
            two dimensions.
    """
    starttime = datetime.datetime.now()

    # Fixed seed for reproducible runs.
    # seed_n = np.random.randint(2025)
    seed_n = 1877
    algo_log(f"seed is {seed_n}", level="debug")
    random.seed(seed_n)
    np.random.seed(seed_n)
    torch.manual_seed(seed_n)
    torch.cuda.manual_seed(seed_n)
    torch.cuda.manual_seed_all(seed_n)

    # ExP requires the channel count; calling it without one raised a
    # TypeError before any training could start.
    if n_chan is None:
        data_shape = np.shape(all_data)
        if len(data_shape) < 2:
            raise ValueError(
                "all_data must have shape (n_trials, n_chan, n_samples) "
                "when n_chan is not given"
            )
        n_chan = data_shape[1]

    exp = ExP(n_chan)
    bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data, all_label, modelPath)
    algo_log('THE BEST ACCURACY IS ' + str(bestAcc), level="debug")

    endtime = datetime.datetime.now()
    algo_log(f"train duration: {endtime - starttime}", level="debug")
2026-06-05 09:34:29 +08:00
if __name__ == "__main__":
    # Log the local wall-clock time at which the script is run directly.
    launch_time = time.asctime(time.localtime(time.time()))
    algo_log(f"[DEBUG] time.asctime(time.localtime(time.time())) = {launch_time}", level="debug")