Files
bci_algo/MI/Algorithm/conformer_2class.py
2026-06-10 16:04:02 +08:00

409 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
EEG Conformer
Convolutional Transformer for EEG decoding
Couple CNN and Transformer in a concise manner with amazing results
"""
# remember to change paths
import os
# Restrict the process to the GPUs listed in `gpus`. The environment variables
# are assigned here, ahead of the torch import further down.
gpus = [0]
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, gpus))
import numpy as np
import math
import random
import time
import datetime
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch
import torch.nn.functional as F
from torch import nn
from torch import Tensor
from einops import rearrange
from einops.layers.torch import Rearrange, Reduce
# from common_spatial_pattern import csp
# from torch.utils.tensorboard import SummaryWriter
from torch.backends import cudnn
# NOTE(review): benchmark=True asks cuDNN to auto-tune its algorithm choice,
# while deterministic=True asks for reproducible kernels; the two goals can
# pull against each other -- confirm which one is intended.
cudnn.benchmark = True
cudnn.deterministic = True
from sklearn.model_selection import train_test_split
# writer = SummaryWriter('./TensorBoardX/')
# Project-local logging helper used by the training code in this module.
from logs.log import algo_log
# Convolution module
# Use convolutions to capture local features instead of a position embedding.
class PatchEmbedding(nn.Module):
    """Convolutional front end that turns an EEG trial into a token sequence.

    A temporal convolution (kernel ``(1, 25)``) is followed by a convolution
    whose kernel spans ``n_chan`` rows, batch norm, ELU, average pooling along
    the last axis and dropout. A 1x1 convolution then maps the 40 feature maps
    to ``emb_size`` and the spatial axes are flattened into the token axis.

    Args:
        emb_size: Number of features per output token.
        n_chan: Height of the second convolution kernel (number of EEG
            channels it spans).
    """

    def __init__(self, emb_size=40, n_chan=8):
        super().__init__()

        temporal_conv = nn.Conv2d(1, 40, (1, 25), (1, 1))
        spatial_conv = nn.Conv2d(40, 40, (n_chan, 1), (1, 1))
        # Pooling acts as slicing: each pooled window becomes one "patch"
        # along the time dimension, as in ViT.
        time_pool = nn.AvgPool2d((1, 75), (1, 15))
        self.shallownet = nn.Sequential(
            temporal_conv,
            spatial_conv,
            nn.BatchNorm2d(40),
            nn.ELU(),
            time_pool,
            nn.Dropout(0.5),
        )

        # The 1x1 convolution slightly improves fitting ability before the
        # feature maps are rearranged into (batch, tokens, features).
        token_conv = nn.Conv2d(40, emb_size, (1, 1), stride=(1, 1))
        self.projection = nn.Sequential(
            token_conv,
            Rearrange('b e (h) (w) -> b (h w) e'),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Return tokens of shape ``(batch, h * w, emb_size)`` for a 4-D input."""
        # The four-way unpack only serves to reject inputs that are not 4-D.
        _batch, _, _, _ = x.shape
        features = self.shallownet(x)
        return self.projection(features)
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over a ``(batch, tokens, emb_size)`` tensor.

    Args:
        emb_size: Feature size of every token; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def _split_heads(self, t: Tensor) -> Tensor:
        """Reshape ``(b, n, h * d)`` into ``(b, h, n, d)``."""
        batch, tokens, _ = t.shape
        return t.reshape(batch, tokens, self.num_heads, -1).transpose(1, 2)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        """Apply self-attention to ``x``.

        Args:
            x: Input of shape ``(batch, tokens, emb_size)``.
            mask: Optional boolean tensor broadcastable to
                ``(batch, heads, tokens, tokens)``. Positions where it is
                ``False`` are excluded from the softmax.

        Returns:
            Tensor with the same shape as ``x``.
        """
        queries = self._split_heads(self.queries(x))
        keys = self._split_heads(self.keys(x))
        values = self._split_heads(self.values(x))

        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        if mask is not None:
            # masked_fill is out-of-place, so the result has to be kept; the
            # fill value is taken from the dtype actually in use so it stays
            # representable (e.g. under half precision).
            fill_value = torch.finfo(energy.dtype).min
            energy = energy.masked_fill(~mask, fill_value)

        # Scaling uses the full embedding size (not the per-head size); this
        # is kept as-is so existing trained weights behave the same.
        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)

        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        # Merge the heads back: (b, h, n, d) -> (b, n, h * d).
        batch, _, tokens, _ = out.shape
        out = out.transpose(1, 2).reshape(batch, tokens, -1)
        return self.projection(out)
class ResidualAdd(nn.Module):
    """Wrap a callable and add its input back onto its output.

    Args:
        fn: Callable (typically a module) applied to the input.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(x, **kwargs) + x``; the sum is written into fn's output."""
        branch_out = self.fn(x, **kwargs)
        # In-place accumulation, exactly like the skip connection it replaces.
        branch_out += x
        return branch_out
class FeedForwardBlock(nn.Sequential):
    """Position-wise feed-forward network: Linear -> GELU -> Dropout -> Linear.

    Args:
        emb_size: Input and output feature size.
        expansion: Multiplier giving the hidden width ``expansion * emb_size``.
        drop_p: Dropout probability applied after the activation.
    """

    def __init__(self, emb_size, expansion, drop_p):
        hidden_size = expansion * emb_size
        expand = nn.Linear(emb_size, hidden_size)
        activation = nn.GELU()
        dropout = nn.Dropout(drop_p)
        contract = nn.Linear(hidden_size, emb_size)
        super().__init__(expand, activation, dropout, contract)
class GELU(nn.Module):
    """Exact (erf-based) Gaussian Error Linear Unit activation."""

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input * 0.5 * (1 + erf(input / sqrt(2)))`` element-wise."""
        erf_term = torch.erf(input / math.sqrt(2.0))
        half_input = input * 0.5
        return half_input * (1.0 + erf_term)
class TransformerEncoderBlock(nn.Sequential):
    """One pre-norm encoder layer: residual attention then residual feed-forward.

    Args:
        emb_size: Token feature size.
        num_heads: Number of attention heads.
        drop_p: Dropout probability used inside attention and after each
            sub-layer.
        forward_expansion: Hidden-width multiplier of the feed-forward block.
        forward_drop_p: Dropout probability inside the feed-forward block.
    """

    def __init__(self,
                 emb_size,
                 num_heads=10,
                 drop_p=0.5,
                 forward_expansion=4,
                 forward_drop_p=0.5):
        attention_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            MultiHeadAttention(emb_size, num_heads, drop_p),
            nn.Dropout(drop_p),
        )
        feed_forward_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            FeedForwardBlock(
                emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
            nn.Dropout(drop_p),
        )
        super().__init__(
            ResidualAdd(attention_branch),
            ResidualAdd(feed_forward_branch),
        )
class TransformerEncoder(nn.Sequential):
    """Stack of ``depth`` encoder blocks, each built with default settings.

    Args:
        depth: Number of ``TransformerEncoderBlock`` layers.
        emb_size: Token feature size passed to every block.
    """

    def __init__(self, depth, emb_size):
        blocks = (TransformerEncoderBlock(emb_size) for _ in range(depth))
        super().__init__(*blocks)
class ClassificationHead(nn.Sequential):
    """Fully connected classifier applied to the flattened token sequence.

    Args:
        emb_size: Token feature size (used by the pooled ``clshead``).
        n_classes: Number of output logits.
        in_features: Width of the flattened encoder output fed to ``fc``.
            The default 2440 matches 61 tokens of 40 features, which is what
            the default patch embedding yields for 1000-sample trials.
    """

    def __init__(self, emb_size, n_classes, in_features=2440):
        super().__init__()

        # Global-average-pooling head. ``forward`` does not use it; it is kept
        # so the parameter set (and seeded initialisation order) of existing
        # models is unchanged.
        self.clshead = nn.Sequential(
            Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, n_classes)
        )
        self.fc = nn.Sequential(
            nn.Linear(in_features, 256),
            nn.ELU(),
            nn.Dropout(0.5),
            nn.Linear(256, 32),
            nn.ELU(),
            nn.Dropout(0.3),
            # Honour n_classes instead of a hard-coded 2 so the logits always
            # match the requested number of classes.
            nn.Linear(32, n_classes)
        )

    def forward(self, x):
        """Flatten everything but the batch axis and return the class logits."""
        x = x.contiguous().view(x.size(0), -1)
        out = self.fc(x)
        return out
class Conformer(nn.Sequential):
    """EEG Conformer: patch embedding -> transformer encoder -> classifier.

    Args:
        emb_size: Token feature size shared by all three stages.
        depth: Number of encoder blocks.
        n_classes: Number of classes passed to the classification head.
        n_chan: Number of EEG channels passed to the patch embedding.
        **kwargs: Accepted for call compatibility and otherwise ignored.
    """

    def __init__(self, emb_size=40, depth=6, n_classes=2, n_chan=8, **kwargs):
        stages = (
            PatchEmbedding(emb_size, n_chan),
            TransformerEncoder(depth, emb_size),
            ClassificationHead(emb_size, n_classes),
        )
        super().__init__(*stages)
class ExP:
    """Training experiment for the two-class EEG Conformer.

    Creating an instance builds the model on the GPU, creates the
    ``online_Models`` directory and opens ``./online_Models/log_result.txt``
    for writing. The log handle stays open for the lifetime of the object and
    is flushed at the end of :meth:`train`.

    Args:
        n_chan: Number of EEG channels in every trial.
    """

    def __init__(self, n_chan):
        super().__init__()
        self.n_chan = n_chan
        self.batch_size = 24
        self.n_epochs = 250
        self.c_dim = 4
        self.lr = 0.0002
        self.b1 = 0.5
        self.b2 = 0.999
        self.start_epoch = 0

        # Create the output directory before opening the log inside it.
        os.makedirs("online_Models", exist_ok=True)
        self.log_write = open("./online_Models/log_result.txt", "w")

        self.Tensor = torch.cuda.FloatTensor
        self.LongTensor = torch.cuda.LongTensor

        self.criterion_cls = torch.nn.CrossEntropyLoss().cuda()

        self.model = Conformer(n_chan=self.n_chan).cuda()
        self.model = nn.DataParallel(self.model, device_ids=[i for i in range(len(gpus))])
        self.model = self.model.cuda()

    def interaug(self, timg, label):
        """Segmentation and Reconstruction (S&R) data augmentation.

        For each of the two classes (labels 1 and 2), ``batch_size / 2`` new
        trials are assembled from eight 125-sample segments, each segment
        copied from a randomly chosen trial of the same class.

        Args:
            timg: Trials indexed as ``(trial, 1, n_chan, 1000)``; a numpy
                array or a torch tensor.
            label: One label (1 or 2) per trial; a numpy array or a torch
                tensor.

        Returns:
            ``(aug_data, aug_label)`` as numpy arrays, shuffled together. The
            caller decides whether to move them to the GPU.

        Raises:
            ValueError: If one of the two classes has no trials.
        """
        # Work on CPU numpy arrays regardless of what was passed in.
        if isinstance(timg, torch.Tensor):
            timg = timg.cpu().numpy()
        if isinstance(label, torch.Tensor):
            label = label.cpu().numpy()
        timg = np.asarray(timg)
        label = np.asarray(label)

        per_class = int(self.batch_size / 2)
        n_segments = 8
        seg_len = 125

        aug_data = []
        aug_label = []
        for cls4aug in range(2):
            cls_value = cls4aug + 1
            cls_idx = np.where(label == cls_value)
            tmp_data = timg[cls_idx]
            if tmp_data.shape[0] == 0:
                raise ValueError(
                    "interaug needs at least one trial with label %d" % cls_value)

            tmp_aug_data = np.zeros((per_class, 1, self.n_chan, n_segments * seg_len))
            for ri in range(per_class):
                for rj in range(n_segments):
                    # The per-segment draw of n_segments indices is kept as-is
                    # so a fixed numpy seed keeps producing the same trials.
                    rand_idx = np.random.randint(0, tmp_data.shape[0], n_segments)
                    seg = slice(rj * seg_len, (rj + 1) * seg_len)
                    tmp_aug_data[ri, :, :, seg] = tmp_data[rand_idx[rj], :, :, seg]

            aug_data.append(tmp_aug_data)
            # One label per augmented trial. Slicing the source labels would
            # come up short whenever a class has fewer than per_class trials.
            aug_label.append(np.full(per_class, cls_value, dtype=label.dtype))

        aug_data = np.concatenate(aug_data)
        aug_label = np.concatenate(aug_label)
        aug_shuffle = np.random.permutation(len(aug_data))
        aug_data = aug_data[aug_shuffle, :, :]
        aug_label = aug_label[aug_shuffle]
        return aug_data, aug_label

    def train(self, all_data, all_label, model_path):
        """Train the model and save it to ``model_path``.

        The data are split 80/20 (stratified, fixed ``random_state``), one
        batch of S&R-augmented trials is added to the training part, and the
        model is evaluated on the held-out part after every epoch.

        Args:
            all_data: Trials indexed as ``(trial, n_chan, samples)``.
            all_label: One label per trial, with values 1 and 2.
            model_path: Destination passed to ``torch.save``.

        Returns:
            ``(bestAcc, averAcc, Y_true, Y_pred)``: the best and mean held-out
            accuracy, plus the held-out labels and predictions of the best
            epoch (both stay ``0`` if accuracy never rises above 0).
        """
        all_data = np.expand_dims(np.array(all_data), axis=1)
        all_label = np.array(all_label)
        train_data, test_data, train_label, test_label = train_test_split(
            all_data, all_label, test_size=0.2,
            random_state=42, stratify=all_label, shuffle=True)

        # Generate the augmented trials once up front instead of per batch,
        # then merge them with the original trials and shuffle everything.
        aug_data, aug_label = self.interaug(train_data, train_label)
        train_data_full = np.concatenate([train_data, aug_data], axis=0)
        train_label_full = np.concatenate([train_label, aug_label], axis=0)
        shuffle_idx = np.random.permutation(len(train_data_full))
        train_data_full = train_data_full[shuffle_idx]
        train_label_full = train_label_full[shuffle_idx]

        # Labels are shifted from {1, 2} to {0, 1} for the cross-entropy loss.
        img = torch.from_numpy(train_data_full)
        label = torch.from_numpy(train_label_full - 1)
        dataset = torch.utils.data.TensorDataset(img, label)
        self.dataloader = torch.utils.data.DataLoader(
            dataset=dataset, batch_size=self.batch_size, shuffle=True)

        test_data = torch.from_numpy(test_data)
        test_label = torch.from_numpy(test_label - 1)
        test_dataset = torch.utils.data.TensorDataset(test_data, test_label)
        self.test_dataloader = torch.utils.data.DataLoader(
            dataset=test_dataset, batch_size=self.batch_size, shuffle=True)

        # Optimizers
        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2))

        # The whole held-out set is evaluated in a single forward pass.
        test_data = test_data.type(self.Tensor)
        test_label = test_label.type(self.LongTensor)

        bestAcc = 0
        averAcc = 0
        num = 0
        Y_true = 0
        Y_pred = 0

        for e in range(self.n_epochs):
            self.model.train()
            for img, label in self.dataloader:
                img = img.cuda().type(self.Tensor)
                label = label.cuda().type(self.LongTensor)

                outputs = self.model(img)
                loss = self.criterion_cls(outputs, label)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            # Held-out evaluation after every epoch; no graph is needed here.
            self.model.eval()
            with torch.no_grad():
                Cls = self.model(test_data)
                loss_test = self.criterion_cls(Cls, test_label)
            y_pred = torch.max(Cls, 1)[1]
            acc = float((y_pred == test_label).sum().item()) / float(test_label.size(0))

            # Train loss/accuracy are those of the last mini-batch of the epoch.
            train_pred = torch.max(outputs, 1)[1]
            train_acc = float((train_pred == label).sum().item()) / float(label.size(0))

            algo_log('Epoch:', e,
                     ' Train loss: %.6f' % loss.detach().item(),
                     ' Test loss: %.6f' % loss_test.item(),
                     ' Train accuracy %.6f' % train_acc,
                     ' Test accuracy is %.6f' % acc, level="debug")
            self.log_write.write(str(e) + " " + str(acc) + "\n")

            num = num + 1
            averAcc = averAcc + acc
            if acc > bestAcc:
                bestAcc = acc
                Y_true = test_label
                Y_pred = y_pred

        # NOTE(review): the model is saved once, after the last epoch, so the
        # file holds the final weights rather than those of the best epoch.
        # Move this under the `acc > bestAcc` branch if the best checkpoint is
        # what downstream code expects -- confirm.
        torch.save(self.model, model_path)

        # Guard against n_epochs == 0, where no evaluation was run.
        averAcc = averAcc / num if num else 0.0
        algo_log('The average accuracy is:', averAcc, level="debug")
        algo_log('The best accuracy is:', bestAcc, level="debug")
        self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n")
        self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n")
        # Make sure the results reach disk even though the handle stays open.
        self.log_write.flush()
        return bestAcc, averAcc, Y_true, Y_pred
# writer.close()
def onlineTrain(data_queue, result_queue):
    """Train a model from data received on a queue and report the outcome.

    Waits up to 30 seconds for one item on ``data_queue``; the item must be a
    dict with the keys ``'data'``, ``'label'``, ``'modelPath'`` and
    ``'n_chan'``. The result is put on ``result_queue`` as a dict whose
    ``'status'`` is ``'success'`` (with ``'model_state'`` and ``'timestamp'``)
    or ``'error'`` (with ``'msg'``).

    Args:
        data_queue: Queue-like object providing ``get(timeout=...)``.
        result_queue: Queue-like object providing ``put(...)``.
    """
    import torch
    import traceback

    algo_log(f"[DEBUG] torch.__version__ = {torch.__version__}", level="debug")
    algo_log(f"[DEBUG] torch.cuda.is_available() = {torch.cuda.is_available()}", level="debug")
    if torch.cuda.is_available():
        algo_log(f"[DEBUG] GPU = {torch.cuda.get_device_name(0)}", level="debug")

    try:
        starttime = datetime.datetime.now()

        # Fixed seed so repeated runs on the same data are comparable.
        seed_n = 1877
        random.seed(seed_n)
        np.random.seed(seed_n)
        torch.manual_seed(seed_n)
        torch.cuda.manual_seed(seed_n)
        torch.cuda.manual_seed_all(seed_n)

        # Fetch the training data from the queue.
        data = data_queue.get(timeout=30)
        all_data = data['data']
        all_label = data['label']
        model_path = data['modelPath']
        n_chan = data['n_chan']

        exp = ExP(n_chan)
        algo_log('训练参数: ', np.shape(all_data), np.shape(all_label), model_path, level="debug")
        bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data, all_label, model_path)
        algo_log('THE BEST ACCURACY IS ' + str(bestAcc), level="debug")

        endtime = datetime.datetime.now()
        algo_log('train duration: ', str(endtime - starttime), level="debug")

        # Hand the location of the saved model back to the caller.
        result_queue.put({
            'status': 'success',
            'model_state': model_path,
            'timestamp': time.time()
        })
    except Exception as e:
        # Keep the full traceback in the log; the queue only carries a short
        # message. Some exceptions (e.g. a queue timeout) stringify to '', so
        # fall back to repr() to keep the message informative.
        algo_log('onlineTrain failed:', traceback.format_exc(), level="debug")
        result_queue.put({'status': 'error', 'msg': str(e) or repr(e)})
def offlineTrain(all_data, all_label, modelPath, n_chan=None):
    """Train a model on in-memory data and save it to ``modelPath``.

    Args:
        all_data: Trials indexed as ``(trial, n_chan, samples)``.
        all_label: One label per trial.
        modelPath: Destination for the saved model.
        n_chan: Number of EEG channels. When omitted it is read from the
            second axis of ``all_data``.
    """
    starttime = datetime.datetime.now()

    # Fixed seed so repeated runs on the same data are comparable.
    seed_n = 1877
    algo_log('seed is ' + str(seed_n), level="debug")
    random.seed(seed_n)
    np.random.seed(seed_n)
    torch.manual_seed(seed_n)
    torch.cuda.manual_seed(seed_n)
    torch.cuda.manual_seed_all(seed_n)

    # ExP requires the channel count; derive it from the data when the caller
    # does not supply it.
    if n_chan is None:
        n_chan = np.shape(all_data)[1]

    exp = ExP(n_chan)
    bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data, all_label, modelPath)
    algo_log('THE BEST ACCURACY IS ' + str(bestAcc), level="debug")

    endtime = datetime.datetime.now()
    algo_log('train duration: ', str(endtime - starttime), level="debug")
if __name__ == "__main__":
    # Running the module directly only logs the current local time.
    algo_log(
        f"[DEBUG] time.asctime(time.localtime(time.time())) = {time.asctime(time.localtime(time.time()))}",
        level="debug",
    )