"""Six-target flicker experiment with online decoding feedback.

Each trial runs three stages in a PsychoPy window:

1. Cue      - all six images are shown, the target one in red.
2. Stimulus - every image flickers with its own square wave until the
              decoder server reports a choice or the stage times out.
3. Feedback - all images are shown steadily and a marker is drawn under
              the image the decoder selected.

Trial labels are sent to the decoder through ``TCPClient``; decoder choices
are read from ``TCPServer.ChoosenNum``; stage markers are pushed to a
LabStreamingLayer outlet.  Results are written to ``EEGFiles/`` as JSON.
"""
import json
import os
import random
import time
from datetime import datetime

import numpy as np
from psychopy import visual, core, logging  # import some libraries from PsychoPy
from psychopy import event
# LAB STREAMING LAYER
from pylsl import StreamInfo, StreamOutlet

from DecoderDW.Server import TCPServer
from DecoderDW.Client import TCPClient
# import subprocess

# ----------------------
# constants
# size of the window
WINWIDTH = 1920
WINHEIGHT = 1080
REFRESH_RATE = 144


def get_keypress():
    """Return the first key name pending in PsychoPy's event buffer, or None."""
    keys = event.getKeys()
    return keys[0] if keys else None


def shutdown(win, client):
    """Tell the decoder to stop saving and predicting, then close and quit.

    Args:
        win: the PsychoPy window to close.
        client: connected decoder client exposing ``send_data(name, value)``.
    """
    client.send_data('saveData', 0)
    client.send_data('predict', 2)
    win.close()
    core.quit()
# end of configuration
# ----------------------


def generate_square_wave(frequency, sampling_rate=REFRESH_RATE, duration=5):
    """Generate a 0/1 square-wave sequence.

    Args:
        frequency (float): wave frequency in Hz.
        sampling_rate (int): samples per second; should match the screen
            refresh rate so that one sample corresponds to one frame.
        duration (float): length of the sequence in seconds.

    Returns:
        list: ``int(duration * sampling_rate)`` values, 1 where the
        underlying sine is >= 0 and 0 elsewhere.
    """
    # Total number of samples.
    n_points = int(duration * sampling_rate)
    # Sample times (named ``t`` so the ``time`` module is not shadowed).
    t = np.linspace(0, duration, n_points, endpoint=False)
    sin_wave = np.sin(2 * np.pi * frequency * t)
    # NOTE(review): at exact zero crossings the floating-point sine can be a
    # tiny positive or negative number, so the sample on a cycle boundary may
    # land on either level - confirm whether one-frame jitter matters here.
    square_wave = np.where(sin_wave >= 0, 1, 0)
    return square_wave.tolist()


def _abort_on_quit(win, client):
    """Shut the experiment down if the next pending key press is 'q'."""
    if get_keypress() == 'q':
        shutdown(win, client)


def _per_frequency_stats(results):
    """Count correct / total / failed trials for each target frequency."""
    stats_by_freq = {}
    for result in results:
        stats = stats_by_freq.setdefault(
            result['target_freq'], {'correct': 0, 'total': 0, 'failed': 0})
        stats['total'] += 1
        if result['status'] == 'Failed':
            stats['failed'] += 1
        elif result['is_correct']:
            stats['correct'] += 1
    return stats_by_freq


def main():
    """Run the full experiment and save the sequence and online results."""
    # ------------------------------------------------------------------
    # main window settings
    main_win = visual.Window(size=(WINWIDTH, WINHEIGHT), units='height',
                             screen=0, fullscr=False,
                             gammaErrorPolicy='warn', color=(0.7, 0.7, 0.7))
    print('starting 1')

    # Set up LabStreamingLayer stream.
    info = StreamInfo(name='psychopy_stimuli', type='Markers',
                      channel_count=1, channel_format='string',
                      source_id='psychopy_stimuli_001')
    outlet = StreamOutlet(info)  # Broadcast the stream.

    # Six targets on a 3x2 grid; index i everywhere refers to the same target.
    image_positions = [(-600, 300), (0, 300), (600, 300),
                       (-600, -200), (0, -200), (600, -200)]
    text_positions = [(-600, 30), (0, 30), (600, 30),
                      (-600, -470), (0, -470), (600, -470)]

    image_stims = [
        visual.ImageStim(main_win, size=(300, 300), pos=pos, units='pix',
                         image='UI/figures/xy.jpg')
        for pos in image_positions
    ]
    # Marker drawn under the decoded target during feedback.
    txt_stims = [
        visual.TextStim(win=main_win, text='△', font='SimHei', height=80,
                        color='black', units='pix', bold=True, italic=False,
                        pos=pos)
        for pos in text_positions
    ]
    # Red variants used to cue the target.
    red_stims = [
        visual.ImageStim(main_win, size=(300, 300), pos=pos, units='pix',
                         image='UI/figures/xy_red.jpg')
        for pos in image_positions
    ]

    frequencies = [25, 26, 27, 28, 29, 30]  # alternatives tried: [9..14], [30..35]
    # One 5-second square wave per target; the stimulus loop wraps around it.
    square_waves = [generate_square_wave(freq, REFRESH_RATE, 5)
                    for freq in frequencies]

    time.sleep(2)

    server = TCPServer()
    server.start()
    client = TCPClient('127.0.0.1', 8099)
    client.connect()
    print('Connected decoder_main')
    # client.send_data('impedance', 1)
    # time.sleep(20)
    # client.send_data('impedance', 2)
    # Send the same list used for the stimuli so decoder and display agree.
    client.send_data('targetFreqs', frequencies)
    time.sleep(1)
    # Turn on whole-session data saving to EEGFiles.
    client.send_data('saveData', 1)

    # Experiment parameters: every frequency is the target `repeats` times,
    # in random order.
    repeats = 3
    seq_freq = np.random.permutation(frequencies * repeats).tolist()
    num_trials = len(seq_freq)  # len(frequencies) * repeats

    # Online decoding accuracy bookkeeping.
    online_results = []        # one result dict per trial
    correct_predictions = 0

    seq_info = {
        'total_trials': num_trials,
        'frequencies': frequencies,
        'sequence': seq_freq,
        'start_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    }
    # Make sure the output directory exists before anything is written to it.
    os.makedirs('EEGFiles', exist_ok=True)
    seq_file_path = f'EEGFiles/sequence_{datetime.now().strftime("%Y-%m-%d-%H-%M-%S")}.json'
    # 'w' rather than 'a': appending a second document would make invalid JSON.
    with open(seq_file_path, 'w', encoding='utf-8') as f:
        json.dump(seq_info, f, ensure_ascii=False, indent=2)

    cue_frames = int(1 * REFRESH_RATE)       # 1 s cue
    stimulus_frames = 6 * REFRESH_RATE       # up to 6 s of flicker
    feedback_frames = 1 * REFRESH_RATE       # 1 s feedback

    # ========================Trials Started======================#
    for trial_count, target_freq in enumerate(seq_freq):
        trial_no = trial_count + 1  # trial_count is 0-based; report 1-based
        target_freq_index = frequencies.index(target_freq)
        print(f'Trials {trial_no}/{num_trials} - Target Frequency: {target_freq}Hz (Label: {target_freq_index + 1})')

        # Stage 1: Cue - show every image, the target one in red.
        client.send_data('setLabelAndTrialInfo', {
            'label': 0,
            'trial_info': {
                'trial': trial_no,
                'phase': 'cue',
                'target_freq': target_freq
            }
        })
        for _ in range(cue_frames):
            _abort_on_quit(main_win, client)
            for i, stim in enumerate(image_stims):
                if i == target_freq_index:
                    red_stims[i].draw()
                else:
                    stim.draw()
            main_win.flip()

        # Stage 2: Stimulus - all targets flicker at their own frequency.
        client.send_data('predict', 1)
        client.send_data('setLabelAndTrialInfo', {
            # +1 because label 0 means "do not record".
            'label': target_freq_index + 1,
            'trial_info': {
                'trial': trial_no,
                'phase': 'stimulus',
                'target_freq': target_freq
            }
        })
        outlet.push_sample(['S 1'])
        for frame in range(stimulus_frames):
            _abort_on_quit(main_win, client)
            # An image is drawn only on frames where its wave is high.
            for wave, stim in zip(square_waves, image_stims):
                if wave[frame % len(wave)] == 1:
                    stim.draw()
            main_win.flip()
            # Stop flickering as soon as the decoder has reported a choice.
            if server.ChoosenNum != -1:
                break

        # Snapshot the decoder's choice once, so the logged result and the
        # feedback display cannot disagree if the server updates it later.
        predicted_freq_index = server.ChoosenNum
        # -1 means "no decision"; any other out-of-range value is treated as a
        # failed decode instead of crashing and losing the collected results.
        decoded = 0 <= predicted_freq_index < len(frequencies)
        predicted_freq = frequencies[predicted_freq_index] if decoded else -1
        is_correct = decoded and predicted_freq_index == target_freq_index
        if is_correct:
            correct_predictions += 1

        online_results.append({
            'trial': trial_no,
            'target_freq': target_freq,
            'target_freq_index': target_freq_index,
            'predicted_freq': predicted_freq,
            'predicted_freq_index': predicted_freq_index,
            'is_correct': is_correct,
            'status': 'Success' if decoded else 'Failed'
        })

        status_symbol = "✓" if is_correct else "✗"
        if not decoded:
            print(f'Trial {trial_no}: 目标{target_freq}Hz -> 解码失败 - {status_symbol}')
        else:
            print(f'Trial {trial_no}: 目标{target_freq}Hz -> 预测{predicted_freq}Hz - {status_symbol}')

        # Stage 3: Decoding feedback.
        outlet.push_sample(['S 2'])
        client.send_data('setLabelAndTrialInfo', {
            'label': 0,  # nothing is labelled during feedback
            'trial_info': {
                'trial': trial_no,
                'phase': 'feedback',
                'target_freq': target_freq
            }
        })
        for _ in range(feedback_frames):
            _abort_on_quit(main_win, client)
            # Show every image steadily (no flicker).
            for stim in image_stims:
                stim.draw()
            # Mark the decoded target, if there was one.
            if decoded:
                txt_stims[predicted_freq_index].draw()
            main_win.flip()

        # Clear the decoder's choice before the next trial starts.
        server.ChoosenNum = -1

    # Overall online decoding accuracy.
    total_trials = len(online_results)
    successful_trials = sum(1 for r in online_results if r['status'] == 'Success')
    failed_trials = sum(1 for r in online_results if r['status'] == 'Failed')
    overall_accuracy = correct_predictions / total_trials if total_trials > 0 else 0

    print(f"Total Accuracy: {overall_accuracy:.3f} ({correct_predictions}/{total_trials})")

    # Accuracy broken down by target frequency.
    print(f"\n=== 按频率分析准确率 ===")
    freq_accuracy = _per_frequency_stats(online_results)

    print(f"{'频率':<8} {'准确率':<8} {'正确/总数':<10} {'失败数':<8}")
    print("-" * 40)
    for freq in sorted(freq_accuracy):
        stats = freq_accuracy[freq]
        accuracy = stats['correct'] / stats['total'] if stats['total'] > 0 else 0
        print(f"{freq}Hz{'':<4} {accuracy:.3f}{'':<4} {stats['correct']}/{stats['total']}{'':<6} {stats['failed']}")

    # Save the online decoding results.
    online_results_file = f'EEGFiles/online_results_{datetime.now().strftime("%Y-%m-%d-%H-%M-%S")}.json'
    online_summary = {
        'total_trials': total_trials,
        'successful_trials': successful_trials,
        'failed_trials': failed_trials,
        'correct_predictions': correct_predictions,
        'overall_accuracy': overall_accuracy,
        # 'freq_accuracy': freq_accuracy,
        'trial_results': online_results,
        # 'end_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    }
    with open(online_results_file, 'w', encoding='utf-8') as f:
        json.dump(online_summary, f, ensure_ascii=False, indent=2)

    # NOTE(review): shutdown() also sends ('saveData', 0) before
    # ('predict', 2); the normal exit path does not - confirm whether
    # recording should be stopped here as well.
    client.send_data('predict', 2)
    # Close the window.
    main_win.close()


if __name__ == "__main__":
    main()