del train
This commit is contained in:
39
Decoder.py
39
Decoder.py
@@ -96,7 +96,7 @@ class Decoder_main(threading.Thread):
|
||||
elif decoder_class == 'ssmvep':
|
||||
self.zmqServer.interval_init(decoder_class)
|
||||
self.n_chan = 8
|
||||
self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
|
||||
self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch')) # [0.2, 2.2]
|
||||
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位
|
||||
self.single_train = 10 # 单类别数量
|
||||
self.num_target = 2 # 分类目标数目
|
||||
@@ -268,26 +268,29 @@ class Decoder_main(threading.Thread):
|
||||
|
||||
'''训练阶段采集数据'''
|
||||
if self.zmqServer.state_mode == 'train': # 训练状态
|
||||
if self.zmqServer.StartTrain:
|
||||
|
||||
|
||||
if self.zmqServer.epoch_finished and self.zmqServer.paradigmBuffer.GetDataLenCount() >= \
|
||||
self.train_epoch[1] + self.zmqServer.event_inner_idx:
|
||||
|
||||
self.currentLabel = self.zmqServer.currentLabel
|
||||
self.zmqServer.StartTrain = False
|
||||
if self.zmqServer.epoch_finished == False or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
|
||||
self.train_epoch[1] \
|
||||
+ self.zmqServer.event_inner_idx:
|
||||
|
||||
print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount())
|
||||
trainTrial = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 取出所有数据
|
||||
|
||||
print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, self.zmqServer.event_inner_idx])
|
||||
trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理
|
||||
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.train_epoch[
|
||||
0]:self.zmqServer.event_inner_idx + self.train_epoch[1]]
|
||||
print('trial: ', self.zmqServer.event_inner_idx, self.train_epoch[0], self.train_epoch[1])
|
||||
if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance(
|
||||
self.trainLabel, list) \
|
||||
and self.trainLabel.count(self.currentLabel) < self.single_train:
|
||||
self.trainData.append(trainTrial)
|
||||
self.trainLabel.append(self.currentLabel)
|
||||
else:
|
||||
time.sleep(0.0001)
|
||||
return
|
||||
print('训练队列数据:', self.zmqServer.paradigmBuffer.GetDataLenCount())
|
||||
trainTrial = self.zmqServer.paradigmBuffer.get_SSMVEPData() # 取出所有数据
|
||||
print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, self.zmqServer.event_inner_idx])
|
||||
trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理
|
||||
trainTrial = trainTrial[:, self.zmqServer.event_inner_idx + self.train_epoch[
|
||||
0]:self.zmqServer.event_inner_idx + self.train_epoch[1]]
|
||||
print('trial: ', self.zmqServer.event_inner_idx, self.train_epoch[0], self.train_epoch[1])
|
||||
if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance(
|
||||
self.trainLabel, list) \
|
||||
and self.trainLabel.count(self.currentLabel) < self.single_train:
|
||||
self.trainData.append(trainTrial)
|
||||
self.trainLabel.append(self.currentLabel)
|
||||
|
||||
elif self.zmqServer.state_mode == 'predict': # 测试状态
|
||||
if self.load_model == False: # 模型尚未训练完成
|
||||
|
||||
Reference in New Issue
Block a user