# Source listing metadata: 410 lines, 14 KiB, Python.
"""
|
|||
|
|
EEG Conformer

Convolutional Transformer for EEG decoding.

Couples a CNN front end and a Transformer encoder in a compact architecture
that achieves strong decoding results.
|
|||
|
|
"""
|
|||
|
|
# NOTE: remember to change the output/model paths (e.g. ./online_Models) for your environment.
|
|||
|
|
|
|||
|
|
import os
|
|||
|
|
# GPU indices this process may use. CUDA_DEVICE_ORDER=PCI_BUS_ID makes the
# CUDA device numbering follow PCI bus order; CUDA_VISIBLE_DEVICES then
# restricts the process to the listed indices (comma-separated).
gpus = [0]
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(gpu_idx) for gpu_idx in gpus)
|
|||
|
|
import numpy as np
|
|||
|
|
import math
|
|||
|
|
import random
|
|||
|
|
import time
|
|||
|
|
import datetime
|
|||
|
|
|
|||
|
|
from torch.utils.data import DataLoader
|
|||
|
|
from torch.autograd import Variable
|
|||
|
|
|
|||
|
|
import torch
|
|||
|
|
import torch.nn.functional as F
|
|||
|
|
from torch import nn
|
|||
|
|
from torch import Tensor
|
|||
|
|
from einops import rearrange
|
|||
|
|
from einops.layers.torch import Rearrange, Reduce
|
|||
|
|
# from common_spatial_pattern import csp
|
|||
|
|
|
|||
|
|
# from torch.utils.tensorboard import SummaryWriter
|
|||
|
|
from torch.backends import cudnn
|
|||
|
|
# Reproducibility: the training entry points in this file pin every RNG seed,
# so cuDNN must be deterministic too. With benchmark=True cuDNN auto-tunes and
# may select different convolution algorithms from run to run, which defeats
# the fixed seeds; benchmark=False trades some speed for repeatable results.
cudnn.benchmark = False
cudnn.deterministic = True
|
|||
|
|
from sklearn.model_selection import train_test_split
|
|||
|
|
# writer = SummaryWriter('./TensorBoardX/')
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Convolution module
|
|||
|
|
# Use convolutions to capture local features instead of a position embedding.
|
|||
|
|
class PatchEmbedding(nn.Module):
    """Shallow convolutional front end that turns EEG trials into tokens.

    Pipeline (``shallownet``): a temporal convolution (kernel 1x25), a spatial
    convolution whose kernel spans ``n_chan`` rows, BatchNorm, ELU, average
    pooling along time (window 75, stride 15) and Dropout(0.5). The pooling
    acts as slicing the time axis into overlapping "patches", as in ViT.
    ``projection`` then applies a 1x1 convolution to ``emb_size`` features and
    flattens the spatial grid into a token axis.

    Args:
        emb_size: Feature size of each output token.
        n_chan: Height of the spatial convolution kernel, i.e. the number of
            EEG channels expected in the input.
    """

    def __init__(self, emb_size=40, n_chan=8):
        super().__init__()

        self.shallownet = nn.Sequential(
            nn.Conv2d(1, 40, (1, 25), (1, 1)),
            nn.Conv2d(40, 40, (n_chan, 1), (1, 1)),
            nn.BatchNorm2d(40),
            nn.ELU(),
            # Pooling acts as slicing to obtain 'patches' along the time axis.
            nn.AvgPool2d((1, 75), (1, 15)),
            nn.Dropout(0.5),
        )

        self.projection = nn.Sequential(
            # A 1x1 conv rather than a plain transpose slightly improves fitting.
            nn.Conv2d(40, emb_size, (1, 1), stride=(1, 1)),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Embed a batch of trials.

        Args:
            x: 4-D tensor (batch, 1, height, time); the first conv requires a
                single input channel. Presumably height == n_chan so the
                spatial conv collapses it to 1 — TODO confirm with callers.

        Returns:
            Tensor of shape (batch, h * w, emb_size) from the Rearrange step.

        Raises:
            ValueError: if ``x`` is not 4-dimensional.
        """
        if x.dim() != 4:
            raise ValueError(
                f"PatchEmbedding expects a 4-D input, got shape {tuple(x.shape)}"
            )
        return self.projection(self.shallownet(x))
|
|||
|
|
|
|||
|
|
|
|||
|
|
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    Args:
        emb_size: Token feature size; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        # Creation order (keys, queries, values) is kept so seeded weight
        # initialisation is unchanged.
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def _split_heads(self, t: Tensor) -> Tensor:
        """Reshape (b, n, h*d) -> (b, h, n, d); head index is the outer factor."""
        b, n, _ = t.shape
        return t.reshape(b, n, self.num_heads, -1).transpose(1, 2)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        """Attend over the tokens of ``x``.

        Args:
            x: Tensor (batch, n_tokens, emb_size).
            mask: Optional boolean tensor broadcastable to
                (batch, heads, n_queries, n_keys); ``True`` keeps a query/key
                pair, ``False`` blocks it.

        Returns:
            Tensor (batch, n_tokens, emb_size).
        """
        queries = self._split_heads(self.queries(x))
        keys = self._split_heads(self.keys(x))
        values = self._split_heads(self.values(x))

        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        if mask is not None:
            # masked_fill is out-of-place, so the result must be reassigned.
            # The dtype minimum (not -inf) keeps fully masked rows NaN-free.
            fill_value = torch.finfo(energy.dtype).min
            energy = energy.masked_fill(~mask, fill_value)

        # NOTE(review): scales by sqrt(emb_size), not the usual sqrt(head_dim);
        # kept as-is to preserve the behaviour of trained models.
        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        out = torch.einsum('bhal, bhlv -> bhav', att, values)

        # Merge heads back: (b, h, n, d) -> (b, n, h*d).
        b, h, n, d = out.shape
        out = out.transpose(1, 2).reshape(b, n, h * d)
        return self.projection(out)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ResidualAdd(nn.Module):
    """Wrap a sub-module with a residual (skip) connection: ``fn(x) + x``.

    Args:
        fn: Module applied to the input; its output must be broadcast-compatible
            with the input.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(x, **kwargs) + x`` without modifying ``x``."""
        # Out-of-place addition: an in-place ``+=`` on fn's output would also
        # mutate the caller's ``x`` whenever ``fn`` returns its input (e.g.
        # nn.Identity) or a view of it.
        return self.fn(x, **kwargs) + x
|
|||
|
|
|
|||
|
|
|
|||
|
|
class FeedForwardBlock(nn.Sequential):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear.

    Args:
        emb_size: Input and output feature size.
        expansion: Hidden width as a multiple of ``emb_size``.
        drop_p: Dropout probability applied after the activation.
    """

    def __init__(self, emb_size, expansion, drop_p):
        hidden_size = expansion * emb_size
        layers = [
            nn.Linear(emb_size, hidden_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden_size, emb_size),
        ]
        super().__init__(*layers)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class GELU(nn.Module):
    """Exact (erf-based) Gaussian Error Linear Unit: x * 0.5 * (1 + erf(x / sqrt(2))).

    NOTE(review): not referenced elsewhere in the visible source; the
    FeedForwardBlock uses ``nn.GELU``.
    """

    def forward(self, input: Tensor) -> Tensor:
        half_input = input * 0.5
        gate = 1.0 + torch.erf(input / math.sqrt(2.0))
        return half_input * gate
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TransformerEncoderBlock(nn.Sequential):
    """Pre-norm Transformer encoder layer built from two residual sub-layers.

    1. ``x + Dropout(MultiHeadAttention(LayerNorm(x)))``
    2. ``x + Dropout(FeedForwardBlock(LayerNorm(x)))``

    Args:
        emb_size: Token feature size.
        num_heads: Attention heads.
        drop_p: Dropout used inside attention and after each sub-layer.
        forward_expansion: MLP hidden width multiplier.
        forward_drop_p: Dropout inside the MLP.
    """

    def __init__(self,
                 emb_size,
                 num_heads=10,
                 drop_p=0.5,
                 forward_expansion=4,
                 forward_drop_p=0.5):
        # Sub-layers are built in the same order as before so seeded
        # parameter initialisation is unchanged.
        attention_sublayer = ResidualAdd(nn.Sequential(
            nn.LayerNorm(emb_size),
            MultiHeadAttention(emb_size, num_heads, drop_p),
            nn.Dropout(drop_p),
        ))
        mlp_sublayer = ResidualAdd(nn.Sequential(
            nn.LayerNorm(emb_size),
            FeedForwardBlock(emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
            nn.Dropout(drop_p),
        ))
        super().__init__(attention_sublayer, mlp_sublayer)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TransformerEncoder(nn.Sequential):
    """Stack of ``depth`` TransformerEncoderBlock layers with default settings.

    Args:
        depth: Number of encoder layers.
        emb_size: Token feature size.
    """

    def __init__(self, depth, emb_size):
        layers = []
        for _ in range(depth):
            layers.append(TransformerEncoderBlock(emb_size))
        super().__init__(*layers)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ClassificationHead(nn.Sequential):
    """Map the encoder's token sequence to class logits.

    ``forward`` flattens all tokens of a sample into one vector and feeds it to
    the MLP ``fc``. ``clshead`` (mean over tokens + LayerNorm + Linear) is
    constructed but NOT used by ``forward``. It is kept so that the parameter
    initialisation order and ``state_dict`` keys stay unchanged.

    Args:
        emb_size: Feature size of each token (used only by ``clshead``).
        n_classes: Number of output logits.
        in_features: Flattened input size (n_tokens * emb_size). The default
            2440 = 61 * 40 is what PatchEmbedding produces for 1000-sample
            trials with emb_size=40. Pass another value for other trial
            lengths or embedding sizes.
    """

    def __init__(self, emb_size, n_classes, in_features=2440):
        super().__init__()

        # Global-average-pooling head; not used by forward().
        self.clshead = nn.Sequential(
            Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, n_classes)
        )
        self.fc = nn.Sequential(
            nn.Linear(in_features, 256),
            nn.ELU(),
            nn.Dropout(0.5),
            nn.Linear(256, 32),
            nn.ELU(),
            nn.Dropout(0.3),
            # One logit per class (was hard-coded to 2, ignoring n_classes).
            nn.Linear(32, n_classes)
        )

    def forward(self, x):
        """Flatten each sample (presumably (batch, n_tokens, emb_size)) and return fc logits."""
        x = x.contiguous().view(x.size(0), -1)
        return self.fc(x)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class Conformer(nn.Sequential):
    """EEG Conformer: PatchEmbedding -> TransformerEncoder -> ClassificationHead.

    Args:
        emb_size: Token feature size.
        depth: Number of Transformer encoder layers.
        n_classes: Number of classes passed to the classification head.
        n_chan: Number of EEG channels (spatial kernel height).
        **kwargs: Accepted and ignored.
    """

    def __init__(self, emb_size=40, depth=6, n_classes=2, n_chan=8, **kwargs):
        # Same construction order as before, so seeded initialisation matches.
        embedding = PatchEmbedding(emb_size, n_chan)
        encoder = TransformerEncoder(depth, emb_size)
        head = ClassificationHead(emb_size, n_classes)
        super().__init__(embedding, encoder, head)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ExP():
    """Training harness for the EEG Conformer.

    ``__init__`` fixes the hyper-parameters, opens a log file under
    ``./online_Models`` and builds a ``Conformer`` wrapped in
    ``nn.DataParallel`` on CUDA. ``train`` does the following:

    - splits the data 80/20 (stratified);
    - augments the training split with Segmentation & Reconstruction (S&R);
    - trains with Adam;
    - evaluates on the held-out split after every epoch.

    NOTE(review): CUDA is required (``torch.cuda`` tensor types and ``.cuda()``
    calls); there is no CPU fallback.
    """

    def __init__(self, n_chan):
        """Set up hyper-parameters, model, loss and log file.

        Args:
            n_chan: Number of EEG channels per trial (spatial kernel height).
        """
        super(ExP, self).__init__()
        self.n_chan = n_chan
        self.batch_size = 24
        self.n_epochs = 250
        self.c_dim = 4  # NOTE(review): not referenced anywhere in this file.
        self.lr = 0.0002
        self.b1 = 0.5  # Adam beta1
        self.b2 = 0.999  # Adam beta2

        self.start_epoch = 0
        # Create the output directory for the log file.
        os.makedirs("online_Models", exist_ok=True)
        # Truncated on every construction and kept open for the instance's
        # lifetime; train() flushes it after each write.
        self.log_write = open("./online_Models/log_result.txt", "w")

        self.Tensor = torch.cuda.FloatTensor
        self.LongTensor = torch.cuda.LongTensor

        self.criterion_cls = torch.nn.CrossEntropyLoss().cuda()

        self.model = Conformer(n_chan=self.n_chan).cuda()
        self.model = nn.DataParallel(self.model, device_ids=list(range(len(gpus))))
        self.model = self.model.cuda()

    # Segmentation and Reconstruction (S&R) data augmentation
    def interaug(self, timg, label):
        """Synthesise trials by stitching time segments of real trials together.

        For each class label 1 and 2, ``batch_size // 2`` new trials are built.
        The time axis is cut into 8 segments of length ``T // 8``, with the
        last segment also taking any remainder. Each segment is copied from a
        randomly chosen trial of that class.

        Args:
            timg: Trials as a numpy array or tensor, indexed
                (trial, 1, channel, time).
            label: Per-trial labels (1-based; only 1 and 2 are augmented).

        Returns:
            ``(aug_data, aug_label)`` as jointly shuffled numpy arrays. A class
            with no trials is skipped; if neither class is present both arrays
            are empty.
        """
        # Work on CPU numpy arrays.
        if isinstance(timg, torch.Tensor):
            timg = timg.cpu().numpy()
        if isinstance(label, torch.Tensor):
            label = label.cpu().numpy()
        timg = np.asarray(timg)
        label = np.asarray(label)

        n_seg = 8
        half = self.batch_size // 2
        n_samples = timg.shape[-1]
        seg_len = n_samples // n_seg

        aug_data = []
        aug_label = []
        for cls4aug in range(2):
            cls_idx = np.where(label == cls4aug + 1)
            tmp_data = timg[cls_idx]
            if tmp_data.shape[0] == 0:
                # Nothing to sample from for this class.
                continue

            tmp_aug_data = np.zeros((half,) + timg.shape[1:])
            for ri in range(half):
                for rj in range(n_seg):
                    # Draw n_seg indices per segment exactly as before, so
                    # seeded runs reproduce the same augmentations.
                    rand_idx = np.random.randint(0, tmp_data.shape[0], n_seg)
                    start = rj * seg_len
                    stop = n_samples if rj == n_seg - 1 else (rj + 1) * seg_len
                    tmp_aug_data[ri, ..., start:stop] = tmp_data[rand_idx[rj], ..., start:stop]

            aug_data.append(tmp_aug_data)
            # One label per generated trial. Slicing the class's real labels
            # comes up short when the class has fewer than `half` trials.
            aug_label.append(np.full(half, cls4aug + 1, dtype=label.dtype))

        if not aug_data:
            return np.zeros((0,) + timg.shape[1:]), np.zeros(0, dtype=label.dtype)

        aug_data = np.concatenate(aug_data)
        aug_label = np.concatenate(aug_label)
        aug_shuffle = np.random.permutation(len(aug_data))
        # Numpy arrays are returned; the caller decides when to move to GPU.
        return aug_data[aug_shuffle], aug_label[aug_shuffle]

    def train(self, all_data, all_label, model_path):
        """Train the model, evaluating on a held-out split after each epoch.

        Args:
            all_data: Trials, array-like indexed (trial, channel, time). A
                singleton axis is inserted at position 1.
            all_label: Per-trial labels, 1-based (shifted to 0-based for the loss).
            model_path: Destination for ``torch.save`` of the final model (the
                whole ``nn.DataParallel`` module is pickled).

        Returns:
            A tuple ``(bestAcc, averAcc, Y_true, Y_pred)``:

            - ``bestAcc``, ``averAcc``: best and mean per-epoch test accuracy.
            - ``Y_true``, ``Y_pred``: test labels and predictions from the best
              epoch. Both stay ``0`` if no epoch exceeds accuracy 0.
        """
        all_data = np.array(all_data)
        all_label = np.array(all_label)
        all_data = np.expand_dims(all_data, axis=1)
        train_data, test_data, train_label, test_label = train_test_split(
            all_data, all_label, test_size=0.2,
            random_state=42, stratify=all_label, shuffle=True)

        # Generate the S&R augmentation once up front (instead of per batch).
        # Merge it with the real training trials and shuffle them together.
        aug_data, aug_label = self.interaug(train_data, train_label)
        train_data_full = np.concatenate([train_data, aug_data], axis=0)
        train_label_full = np.concatenate([train_label, aug_label], axis=0)
        shuffle_idx = np.random.permutation(len(train_data_full))
        train_data_full = train_data_full[shuffle_idx]
        train_label_full = train_label_full[shuffle_idx]

        img = torch.from_numpy(train_data_full)
        label = torch.from_numpy(train_label_full - 1)
        dataset = torch.utils.data.TensorDataset(img, label)
        self.dataloader = torch.utils.data.DataLoader(
            dataset=dataset, batch_size=self.batch_size, shuffle=True)

        test_data = torch.from_numpy(test_data)
        test_label = torch.from_numpy(test_label - 1)
        test_dataset = torch.utils.data.TensorDataset(test_data, test_label)
        # NOTE(review): not used below; evaluation runs on the full test tensor.
        self.test_dataloader = torch.utils.data.DataLoader(
            dataset=test_dataset, batch_size=self.batch_size, shuffle=True)

        # Optimizer
        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2))

        test_data = test_data.type(self.Tensor)
        test_label = test_label.type(self.LongTensor)

        bestAcc = 0
        averAcc = 0
        num = 0
        Y_true = 0
        Y_pred = 0

        # Train the model
        for e in range(self.n_epochs):
            self.model.train()
            for img, label in self.dataloader:
                img = img.cuda().type(self.Tensor)
                label = label.cuda().type(self.LongTensor)

                outputs = self.model(img)
                loss = self.criterion_cls(outputs, label)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            # Evaluate on the held-out split every epoch. no_grad avoids
            # building an autograd graph over the whole test set.
            self.model.eval()
            with torch.no_grad():
                Cls = self.model(test_data)
                loss_test = self.criterion_cls(Cls, test_label)
                y_pred = torch.max(Cls, 1)[1]
            acc = float((y_pred == test_label).cpu().numpy().astype(int).sum()) / float(test_label.size(0))
            # Training accuracy reflects only the last mini-batch of the epoch.
            train_pred = torch.max(outputs, 1)[1]
            train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))

            print('Epoch:', e,
                  ' Train loss: %.6f' % loss.detach().cpu().numpy(),
                  ' Test loss: %.6f' % loss_test.detach().cpu().numpy(),
                  ' Train accuracy %.6f' % train_acc,
                  ' Test accuracy is %.6f' % acc)

            self.log_write.write(str(e) + " " + str(acc) + "\n")
            self.log_write.flush()
            num = num + 1
            averAcc = averAcc + acc
            if acc > bestAcc:
                bestAcc = acc
                Y_true = test_label
                Y_pred = y_pred

        torch.save(self.model, model_path)
        # Guard against n_epochs == 0.
        averAcc = averAcc / num if num else 0.0
        print('The average accuracy is:', averAcc)
        print('The best accuracy is:', bestAcc)
        self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n")
        self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n")
        self.log_write.flush()

        return bestAcc, averAcc, Y_true, Y_pred
|
|||
|
|
# writer.close()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def onlineTrain(data_queue, result_queue):
    """Train a Conformer from a queued job and report the outcome on a queue.

    Waits up to 30 s for one message on ``data_queue``. The message is a dict
    with keys ``'data'``, ``'label'``, ``'modelPath'`` and ``'n_chan'``.

    On success it puts ``{'status': 'success', 'model_state': <model path>,
    'timestamp': <epoch seconds>}`` on ``result_queue``. If any exception
    occurs it puts ``{'status': 'error', 'msg': ..., 'traceback': ...}``
    instead. In both cases the training log file is closed afterwards.
    """
    import traceback

    print(f"[DEBUG] torch.__version__ = {torch.__version__}")
    print(f"[DEBUG] torch.cuda.is_available() = {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"[DEBUG] GPU = {torch.cuda.get_device_name(0)}")

    exp = None
    try:
        starttime = datetime.datetime.now()

        # Fixed seed for reproducible runs.
        seed_n = 1877
        random.seed(seed_n)
        np.random.seed(seed_n)
        torch.manual_seed(seed_n)
        torch.cuda.manual_seed(seed_n)
        torch.cuda.manual_seed_all(seed_n)

        # Fetch the training job from the queue.
        data = data_queue.get(timeout=30)
        all_data, all_label, model_path, n_chan = (
            data['data'], data['label'], data['modelPath'], data['n_chan'])
        exp = ExP(n_chan)
        print('训练参数: ', np.shape(all_data), np.shape(all_label), model_path)
        bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data, all_label, model_path)
        print('THE BEST ACCURACY IS ' + str(bestAcc))

        endtime = datetime.datetime.now()
        print('train duration: ', str(endtime - starttime))

        # Report back where the trained model was saved.
        result_queue.put({
            'status': 'success',
            'model_state': model_path,  # path the model was saved to
            'timestamp': time.time()
        })
    except Exception as e:
        result_queue.put({
            'status': 'error',
            # Some exceptions (e.g. queue.Empty) have an empty message.
            'msg': str(e) or type(e).__name__,
            'traceback': traceback.format_exc(),
        })
    finally:
        if exp is not None:
            exp.log_write.close()
|
|||
|
|
|
|||
|
|
def offlineTrain(all_data, all_label, modelPath, n_chan=None):
    """Train a Conformer synchronously on in-memory data.

    Args:
        all_data: Trials, array-like indexed (trial, channel, time).
        all_label: Per-trial labels (1-based).
        modelPath: Destination file for the trained model.
        n_chan: Number of EEG channels. When None it is inferred from
            ``all_data`` (axis 1). ExP requires it; it was previously never
            passed, which made this function always fail.

    Returns:
        ``(bestAcc, averAcc, Y_true, Y_pred)`` as returned by ``ExP.train``.
    """
    starttime = datetime.datetime.now()

    # Fixed seed for reproducible runs.
    seed_n = 1877
    print('seed is ' + str(seed_n))
    random.seed(seed_n)
    np.random.seed(seed_n)
    torch.manual_seed(seed_n)
    torch.cuda.manual_seed(seed_n)
    torch.cuda.manual_seed_all(seed_n)

    if n_chan is None:
        n_chan = np.shape(all_data)[1]

    exp = ExP(n_chan)
    try:
        results = exp.train(all_data, all_label, modelPath)
    finally:
        exp.log_write.close()

    bestAcc = results[0]
    print('THE BEST ACCURACY IS ' + str(bestAcc))

    endtime = datetime.datetime.now()
    print('train duration: ', str(endtime - starttime))
    return results
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    # Placeholder entry point: print the current local time twice.
    for _ in range(2):
        now = time.localtime(time.time())
        print(time.asctime(now))
|