replace print with algo_log
This commit is contained in:
@@ -34,7 +34,7 @@ cudnn.benchmark = True
|
||||
cudnn.deterministic = True
|
||||
from sklearn.model_selection import train_test_split
|
||||
# writer = SummaryWriter('./TensorBoardX/')
|
||||
|
||||
from logs.log import algo_log
|
||||
|
||||
# Convolution module
|
||||
# use conv to capture local features, instead of postion embedding.
|
||||
@@ -318,11 +318,11 @@ class ExP():
|
||||
train_pred = torch.max(outputs, 1)[1]
|
||||
train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))
|
||||
|
||||
print('Epoch:', e,
|
||||
algo_log('Epoch:', e,
|
||||
' Train loss: %.6f' % loss.detach().cpu().numpy(),
|
||||
' Test loss: %.6f' % loss_test.detach().cpu().numpy(),
|
||||
' Train accuracy %.6f' % train_acc,
|
||||
' Test accuracy is %.6f' % acc)
|
||||
' Test accuracy is %.6f' % acc, level="debug")
|
||||
|
||||
self.log_write.write(str(e) + " " + str(acc) + "\n")
|
||||
num = num + 1
|
||||
@@ -335,8 +335,8 @@ class ExP():
|
||||
|
||||
torch.save(self.model, model_path)
|
||||
averAcc = averAcc / num
|
||||
print('The average accuracy is:', averAcc)
|
||||
print('The best accuracy is:', bestAcc)
|
||||
algo_log('The average accuracy is:', averAcc, level="debug")
|
||||
algo_log('The best accuracy is:', bestAcc, level="debug")
|
||||
self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n")
|
||||
self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n")
|
||||
|
||||
@@ -346,10 +346,10 @@ class ExP():
|
||||
|
||||
def onlineTrain(data_queue,result_queue):
|
||||
import torch
|
||||
print(f"[DEBUG] torch.__version__ = {torch.__version__}")
|
||||
print(f"[DEBUG] torch.cuda.is_available() = {torch.cuda.is_available()}")
|
||||
algo_log(f"[DEBUG] torch.__version__ = {torch.__version__}", level="debug")
|
||||
algo_log(f"[DEBUG] torch.cuda.is_available() = {torch.cuda.is_available()}", level="debug")
|
||||
if torch.cuda.is_available():
|
||||
print(f"[DEBUG] GPU = {torch.cuda.get_device_name(0)}")
|
||||
algo_log(f"[DEBUG] GPU = {torch.cuda.get_device_name(0)}", level="debug")
|
||||
try:
|
||||
starttime = datetime.datetime.now()
|
||||
|
||||
@@ -366,12 +366,12 @@ def onlineTrain(data_queue,result_queue):
|
||||
data = data_queue.get(timeout=30)
|
||||
all_data, all_label,model_path,n_chan = data['data'], data['label'],data['modelPath'],data['n_chan']
|
||||
exp = ExP(n_chan)
|
||||
print('训练参数: ',np.shape(all_data),np.shape(all_label),model_path)
|
||||
algo_log('训练参数: ',np.shape(all_data),np.shape(all_label),model_path, level="debug")
|
||||
bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,model_path)
|
||||
print('THE BEST ACCURACY IS ' + str(bestAcc))
|
||||
algo_log('THE BEST ACCURACY IS ' + str(bestAcc), level="debug")
|
||||
|
||||
endtime = datetime.datetime.now()
|
||||
print('train duration: ',str(endtime - starttime))
|
||||
algo_log('train duration: ',str(endtime - starttime), level="debug")
|
||||
|
||||
# 将模型或参数传回
|
||||
result_queue.put({
|
||||
@@ -387,7 +387,7 @@ def offlineTrain(all_data,all_label,modelPath):
|
||||
|
||||
# seed_n = np.random.randint(2025)
|
||||
seed_n = 1877
|
||||
print('seed is ' + str(seed_n))
|
||||
algo_log('seed is ' + str(seed_n), level="debug")
|
||||
random.seed(seed_n)
|
||||
np.random.seed(seed_n)
|
||||
torch.manual_seed(seed_n)
|
||||
@@ -397,13 +397,12 @@ def offlineTrain(all_data,all_label,modelPath):
|
||||
exp = ExP()
|
||||
|
||||
bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,modelPath)
|
||||
print('THE BEST ACCURACY IS ' + str(bestAcc))
|
||||
algo_log('THE BEST ACCURACY IS ' + str(bestAcc), level="debug")
|
||||
|
||||
endtime = datetime.datetime.now()
|
||||
print('train duration: ',str(endtime - starttime))
|
||||
algo_log('train duration: ',str(endtime - starttime), level="debug")
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(time.asctime(time.localtime(time.time())))
|
||||
print(time.asctime(time.localtime(time.time())))
|
||||
algo_log(f"[DEBUG] time.asctime(time.localtime(time.time())) = {time.asctime(time.localtime(time.time()))}", level="debug")
|
||||
|
||||
Reference in New Issue
Block a user