replace print with algo_log
This commit is contained in:
@@ -14,8 +14,8 @@ from torch.autograd import Variable
|
|||||||
# from Device.SunnyLinker import SunnyLinker64
|
# from Device.SunnyLinker import SunnyLinker64
|
||||||
from SSMVEP.algorithm.tdca import TDCA
|
from SSMVEP.algorithm.tdca import TDCA
|
||||||
from SSMVEP.algorithm.base import generate_cca_references
|
from SSMVEP.algorithm.base import generate_cca_references
|
||||||
from concentration.algorithm.calculate_focus import Calculate
|
# from concentration.algorithm.calculate_focus import Calculate
|
||||||
from blinkdetection.algorithm.eye_detection import blink_detection
|
# from blinkdetection.algorithm.eye_detection import blink_detection
|
||||||
from Zmq.zmqServer import zmqServer
|
from Zmq.zmqServer import zmqServer
|
||||||
from Zmq.zmqClient import zmqClient
|
from Zmq.zmqClient import zmqClient
|
||||||
from MI.Algorithm.conformer_2class import onlineTrain
|
from MI.Algorithm.conformer_2class import onlineTrain
|
||||||
@@ -425,7 +425,7 @@ class Decoder_main(threading.Thread):
|
|||||||
algo_log(f"MI运动意图识别: {y_pred}")
|
algo_log(f"MI运动意图识别: {y_pred}")
|
||||||
self.zmqServer.broadcast_message('paradigm', int(y_pred.item()))
|
self.zmqServer.broadcast_message('paradigm', int(y_pred.item()))
|
||||||
end = time.time()
|
end = time.time()
|
||||||
print(f'发送给界面完成,耗时{end - start:.3f}s。')
|
algo_log(f'MI发送给界面完成,耗时{end - start:.3f}s。')
|
||||||
else: # 休息状态
|
else: # 休息状态
|
||||||
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
|
if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
|
||||||
time.sleep(0.005)
|
time.sleep(0.005)
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ cudnn.benchmark = True
|
|||||||
cudnn.deterministic = True
|
cudnn.deterministic = True
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
# writer = SummaryWriter('./TensorBoardX/')
|
# writer = SummaryWriter('./TensorBoardX/')
|
||||||
|
from logs.log import algo_log
|
||||||
|
|
||||||
# Convolution module
|
# Convolution module
|
||||||
# use conv to capture local features, instead of postion embedding.
|
# use conv to capture local features, instead of postion embedding.
|
||||||
@@ -318,11 +318,11 @@ class ExP():
|
|||||||
train_pred = torch.max(outputs, 1)[1]
|
train_pred = torch.max(outputs, 1)[1]
|
||||||
train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))
|
train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))
|
||||||
|
|
||||||
print('Epoch:', e,
|
algo_log('Epoch:', e,
|
||||||
' Train loss: %.6f' % loss.detach().cpu().numpy(),
|
' Train loss: %.6f' % loss.detach().cpu().numpy(),
|
||||||
' Test loss: %.6f' % loss_test.detach().cpu().numpy(),
|
' Test loss: %.6f' % loss_test.detach().cpu().numpy(),
|
||||||
' Train accuracy %.6f' % train_acc,
|
' Train accuracy %.6f' % train_acc,
|
||||||
' Test accuracy is %.6f' % acc)
|
' Test accuracy is %.6f' % acc, level="debug")
|
||||||
|
|
||||||
self.log_write.write(str(e) + " " + str(acc) + "\n")
|
self.log_write.write(str(e) + " " + str(acc) + "\n")
|
||||||
num = num + 1
|
num = num + 1
|
||||||
@@ -335,8 +335,8 @@ class ExP():
|
|||||||
|
|
||||||
torch.save(self.model, model_path)
|
torch.save(self.model, model_path)
|
||||||
averAcc = averAcc / num
|
averAcc = averAcc / num
|
||||||
print('The average accuracy is:', averAcc)
|
algo_log('The average accuracy is:', averAcc, level="debug")
|
||||||
print('The best accuracy is:', bestAcc)
|
algo_log('The best accuracy is:', bestAcc, level="debug")
|
||||||
self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n")
|
self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n")
|
||||||
self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n")
|
self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n")
|
||||||
|
|
||||||
@@ -346,10 +346,10 @@ class ExP():
|
|||||||
|
|
||||||
def onlineTrain(data_queue,result_queue):
|
def onlineTrain(data_queue,result_queue):
|
||||||
import torch
|
import torch
|
||||||
print(f"[DEBUG] torch.__version__ = {torch.__version__}")
|
algo_log(f"[DEBUG] torch.__version__ = {torch.__version__}", level="debug")
|
||||||
print(f"[DEBUG] torch.cuda.is_available() = {torch.cuda.is_available()}")
|
algo_log(f"[DEBUG] torch.cuda.is_available() = {torch.cuda.is_available()}", level="debug")
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
print(f"[DEBUG] GPU = {torch.cuda.get_device_name(0)}")
|
algo_log(f"[DEBUG] GPU = {torch.cuda.get_device_name(0)}", level="debug")
|
||||||
try:
|
try:
|
||||||
starttime = datetime.datetime.now()
|
starttime = datetime.datetime.now()
|
||||||
|
|
||||||
@@ -366,12 +366,12 @@ def onlineTrain(data_queue,result_queue):
|
|||||||
data = data_queue.get(timeout=30)
|
data = data_queue.get(timeout=30)
|
||||||
all_data, all_label,model_path,n_chan = data['data'], data['label'],data['modelPath'],data['n_chan']
|
all_data, all_label,model_path,n_chan = data['data'], data['label'],data['modelPath'],data['n_chan']
|
||||||
exp = ExP(n_chan)
|
exp = ExP(n_chan)
|
||||||
print('训练参数: ',np.shape(all_data),np.shape(all_label),model_path)
|
algo_log('训练参数: ',np.shape(all_data),np.shape(all_label),model_path, level="debug")
|
||||||
bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,model_path)
|
bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,model_path)
|
||||||
print('THE BEST ACCURACY IS ' + str(bestAcc))
|
algo_log('THE BEST ACCURACY IS ' + str(bestAcc), level="debug")
|
||||||
|
|
||||||
endtime = datetime.datetime.now()
|
endtime = datetime.datetime.now()
|
||||||
print('train duration: ',str(endtime - starttime))
|
algo_log('train duration: ',str(endtime - starttime), level="debug")
|
||||||
|
|
||||||
# 将模型或参数传回
|
# 将模型或参数传回
|
||||||
result_queue.put({
|
result_queue.put({
|
||||||
@@ -387,7 +387,7 @@ def offlineTrain(all_data,all_label,modelPath):
|
|||||||
|
|
||||||
# seed_n = np.random.randint(2025)
|
# seed_n = np.random.randint(2025)
|
||||||
seed_n = 1877
|
seed_n = 1877
|
||||||
print('seed is ' + str(seed_n))
|
algo_log('seed is ' + str(seed_n), level="debug")
|
||||||
random.seed(seed_n)
|
random.seed(seed_n)
|
||||||
np.random.seed(seed_n)
|
np.random.seed(seed_n)
|
||||||
torch.manual_seed(seed_n)
|
torch.manual_seed(seed_n)
|
||||||
@@ -397,13 +397,12 @@ def offlineTrain(all_data,all_label,modelPath):
|
|||||||
exp = ExP()
|
exp = ExP()
|
||||||
|
|
||||||
bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,modelPath)
|
bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,modelPath)
|
||||||
print('THE BEST ACCURACY IS ' + str(bestAcc))
|
algo_log('THE BEST ACCURACY IS ' + str(bestAcc), level="debug")
|
||||||
|
|
||||||
endtime = datetime.datetime.now()
|
endtime = datetime.datetime.now()
|
||||||
print('train duration: ',str(endtime - starttime))
|
algo_log('train duration: ',str(endtime - starttime), level="debug")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
print(time.asctime(time.localtime(time.time())))
|
algo_log(f"[DEBUG] time.asctime(time.localtime(time.time())) = {time.asctime(time.localtime(time.time()))}", level="debug")
|
||||||
print(time.asctime(time.localtime(time.time())))
|
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from einops import rearrange
|
|||||||
from einops.layers.torch import Rearrange, Reduce
|
from einops.layers.torch import Rearrange, Reduce
|
||||||
from torch.backends import cudnn
|
from torch.backends import cudnn
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
|
from logs.log import algo_log
|
||||||
# writer = SummaryWriter('./TensorBoardX/')
|
# writer = SummaryWriter('./TensorBoardX/')
|
||||||
|
|
||||||
|
|
||||||
@@ -190,7 +191,7 @@ class ExP():
|
|||||||
# 自动选择设备:有 GPU 用 GPU,否则用 CPU
|
# 自动选择设备:有 GPU 用 GPU,否则用 CPU
|
||||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
# self.device = torch.device("cpu")
|
# self.device = torch.device("cpu")
|
||||||
print(f"Using device: {self.device}")
|
algo_log(f"Using device: {self.device}", level="debug")
|
||||||
|
|
||||||
# 定义张量类型(不再强制使用 cuda)
|
# 定义张量类型(不再强制使用 cuda)
|
||||||
self.Tensor = torch.FloatTensor
|
self.Tensor = torch.FloatTensor
|
||||||
|
|||||||
@@ -12,16 +12,17 @@ from scipy.io import loadmat
|
|||||||
from scipy.linalg import qr
|
from scipy.linalg import qr
|
||||||
from scipy.signal import filtfilt, lfilter
|
from scipy.signal import filtfilt, lfilter
|
||||||
# from numpy.linalg import _umath_linalg
|
# from numpy.linalg import _umath_linalg
|
||||||
|
from logs.log import algo_log
|
||||||
|
|
||||||
|
|
||||||
class FbccaDw:
|
class FbccaDw:
|
||||||
def __init__(self, fs, num_target, num_chans, num_filter, num_harms, stimTime, parameter, width, winNum,method):
|
def __init__(self, fs, num_target, num_chans, num_filter, num_harms, stimTime, parameter, width, winNum,method):
|
||||||
print('******************************************')
|
algo_log('******************************************', level="debug")
|
||||||
print('parameter list')
|
algo_log('parameter list', level="debug")
|
||||||
print('target:', num_target)
|
algo_log('target:', num_target, level="debug")
|
||||||
print('number of filter bank:', num_filter)
|
algo_log('number of filter bank:', num_filter, level="debug")
|
||||||
print('parameter:', parameter)
|
algo_log('parameter:', parameter, level="debug")
|
||||||
print('width:', width)
|
algo_log('width:', width, level="debug")
|
||||||
self.phase = 0
|
self.phase = 0
|
||||||
self.bandWidth = width
|
self.bandWidth = width
|
||||||
self.winNum = winNum
|
self.winNum = winNum
|
||||||
@@ -237,7 +238,7 @@ class FbccaDw:
|
|||||||
dataFiltered, self.notchZh[0] = lfilter(self.north_b, self.north_a, data, zi=self.notchZh[0])
|
dataFiltered, self.notchZh[0] = lfilter(self.north_b, self.north_a, data, zi=self.notchZh[0])
|
||||||
return np.asmatrix(dataFiltered)
|
return np.asmatrix(dataFiltered)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(Exception)
|
algo_log(f"Exception: {Exception}", level="debug")
|
||||||
|
|
||||||
'''
|
'''
|
||||||
getDataQ
|
getDataQ
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
import ast
|
import ast
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import threading
|
import threading
|
||||||
|
import zmq
|
||||||
import json
|
import json
|
||||||
import queue
|
import queue
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
@@ -13,17 +14,14 @@ from Zmq.filterProcess import FilterRingBuffer
|
|||||||
from PubLibrary.InifileHelper import IniRead
|
from PubLibrary.InifileHelper import IniRead
|
||||||
from logs.log import algo_log
|
from logs.log import algo_log
|
||||||
|
|
||||||
import zmq
|
zmqServer_host = str(IniRead('system', 'zmqServer_host', '127.0.0.1'))
|
||||||
|
|
||||||
class zmqServer(threading.Thread):
|
class zmqServer(threading.Thread):
|
||||||
def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100, device_info=None):
|
def __init__(self, host='0.0.0.0', cmd_port=8099, data_port=8100, device_info=None):
|
||||||
threading.Thread.__init__(self)
|
threading.Thread.__init__(self)
|
||||||
self.device_info = device_info
|
self.device_info = device_info
|
||||||
|
|
||||||
self.host = host
|
self.host = zmqServer_host
|
||||||
|
|
||||||
# test_host = "192.168.254.102"
|
|
||||||
# self.host = test_host
|
|
||||||
|
|
||||||
self.cmd_port = cmd_port # 命令交互端口:收JSON命令 + 返JSON结果
|
self.cmd_port = cmd_port # 命令交互端口:收JSON命令 + 返JSON结果
|
||||||
self.data_port = data_port # 数据交互端口:收二进制原始脑电 + 返二进制滤波结果
|
self.data_port = data_port # 数据交互端口:收二进制原始脑电 + 返二进制滤波结果
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ Serial_port = COM44
|
|||||||
algo_log_level = DEBUG
|
algo_log_level = DEBUG
|
||||||
console_output = 1
|
console_output = 1
|
||||||
save_train_data = 0
|
save_train_data = 0
|
||||||
|
zmqServer_host = 127.0.0.1
|
||||||
|
|
||||||
; 64 导设备配置
|
; 64 导设备配置
|
||||||
[device_type_1]
|
[device_type_1]
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ echo "输出目录:${OUT_DIR}"
|
|||||||
python -m nuitka \
|
python -m nuitka \
|
||||||
--standalone \
|
--standalone \
|
||||||
--msvc=latest \
|
--msvc=latest \
|
||||||
--windows-console-mode=force \
|
--windows-console-mode=disable \
|
||||||
--module-parameter=torch-disable-jit=yes \
|
--module-parameter=torch-disable-jit=yes \
|
||||||
--enable-plugin=no-qt \
|
--enable-plugin=no-qt \
|
||||||
--include-package=numpy \
|
--include-package=numpy \
|
||||||
|
|||||||
Reference in New Issue
Block a user