original push
This commit is contained in:
16
algorithm_V0/datacollect/RunOnce.py
Normal file
16
algorithm_V0/datacollect/RunOnce.py
Normal file
@@ -0,0 +1,16 @@
|
||||
import ctypes
|
||||
import sys
|
||||
|
||||
|
||||
def is_program_running(name='Global\\Parser_main'):
|
||||
# 创建互斥体
|
||||
mutex_name =name
|
||||
h_mutex = ctypes.windll.kernel32.CreateMutexW(None, False, mutex_name)
|
||||
|
||||
# 检查互斥体是否已经存在
|
||||
if ctypes.windll.kernel32.GetLastError() == 183: # ERROR_ALREADY_EXISTS
|
||||
print("程序已经在运行.")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
379
algorithm_V0/datacollect/SunnyLinker.py
Normal file
379
algorithm_V0/datacollect/SunnyLinker.py
Normal file
@@ -0,0 +1,379 @@
|
||||
# -*-coding:utf-8 -*-
|
||||
'''
|
||||
SunnyLinker的通讯驱动
|
||||
'''
|
||||
import ast
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
import datetime
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
from threading import Thread, Event
|
||||
import serial
|
||||
from scipy import signal
|
||||
from serial.serialutil import SerialException
|
||||
|
||||
from protocol import ProtocolFrame
|
||||
|
||||
class RingBuffer:
|
||||
def __init__(self, n_chan, n_points):
|
||||
self.n_chan = n_chan
|
||||
self.n_points = n_points
|
||||
self.buffer = np.zeros((n_chan, n_points))
|
||||
self.currentPtr = 0
|
||||
self.readPtr = 0
|
||||
self.nUpdate = 0
|
||||
self.rawData = np.zeros((n_chan, 1))
|
||||
|
||||
## append buffer and update current pointer
|
||||
def appendBuffer(self, data):
|
||||
if self.nUpdate == self.n_points:
|
||||
raise Exception("Buffer is full")
|
||||
|
||||
n = data.shape[1]
|
||||
|
||||
# 计算可以写入的元素数量
|
||||
write_count = min(self.n_points - self.nUpdate, n)
|
||||
# 写入新数据
|
||||
self.buffer[:, np.mod(np.arange(self.currentPtr, self.currentPtr + write_count), self.n_points)] = data[:,:write_count]
|
||||
# 更新结束指针
|
||||
self.currentPtr = (self.currentPtr + write_count) % self.n_points
|
||||
# 更新大小
|
||||
self.nUpdate += write_count
|
||||
|
||||
## get data from buffer
|
||||
def getData(self, count=50):
|
||||
# 确保不会尝试读取超过缓冲区当前大小的数据
|
||||
count = min(count, self.nUpdate)
|
||||
|
||||
# 计算读取结束后的下一个位置
|
||||
next_read_ptr = (self.readPtr + count) % self.n_points
|
||||
if self.readPtr + count <= self.n_points:
|
||||
# 情况 1:不环绕,数据是连续的
|
||||
end_index = next_read_ptr if next_read_ptr != 0 else self.n_points
|
||||
data = self.buffer[:, self.readPtr:end_index]
|
||||
else:
|
||||
# 情况 2:发生环绕,数据被分成两部分
|
||||
# 第一部分:从 readPtr 到缓冲区末尾
|
||||
part1 = self.buffer[:, self.readPtr:]
|
||||
# 第二部分:从缓冲区开头到 (count - part1.shape[1]) 个点
|
||||
part2 = self.buffer[:, :next_read_ptr]
|
||||
# 将两部分在列方向上拼接
|
||||
data = np.concatenate((part1, part2), axis=1)
|
||||
|
||||
# 更新读指针
|
||||
self.readPtr = next_read_ptr
|
||||
# 更新大小
|
||||
self.nUpdate -= count
|
||||
return data
|
||||
|
||||
# reset buffer
|
||||
def resetAllPara(self):
|
||||
self.nUpdate = 0
|
||||
self.currentPtr = 0
|
||||
self.readPtr = 0 # add by lizhenhua 清空读指针
|
||||
self.buffer = np.zeros((self.n_chan, self.n_points)) # add by lizhenhua 清空环形缓冲区
|
||||
|
||||
|
||||
class SunnyLinker64(Thread, ):
|
||||
t_buffer = 10
|
||||
n_chan = 64
|
||||
srate = 250
|
||||
receiveData = b''
|
||||
toUv=True#转为uV
|
||||
RingBufferLock = threading.Lock()
|
||||
|
||||
# 单例模式
|
||||
_instance = None
|
||||
_initialized = False # 检查是否已经初始化
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls._instance is None:
|
||||
cls._instance = super(SunnyLinker64, cls).__new__(cls)
|
||||
return cls._instance
|
||||
def __init__(self, host='127.0.0.1', port=7878, srate=250, n_chan=64,method = 'tcp'):
|
||||
if SunnyLinker64._initialized:
|
||||
return
|
||||
Thread.__init__(self)
|
||||
self.daemon = True
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.srate = srate
|
||||
self.n_chan = n_chan
|
||||
self.method = method #传输方式,'tcp'表示tcp传输,'serial'表示串口传输
|
||||
self.__ringBuffer = RingBuffer(self.n_chan + 2,
|
||||
int(np.round(self.t_buffer * self.srate)))
|
||||
self.energy = 0 # 电量
|
||||
self.status_code = 0 # 与采集设备通信的状态码,0为异常,1为正常
|
||||
self.gain_value = 6 # 增益倍数
|
||||
|
||||
# 设置初始化标志为True,防止重复初始化
|
||||
SunnyLinker64._initialized = True
|
||||
|
||||
# --- 新增:用于心跳检测 ---
|
||||
self.last_called = 0 # 初始化为0
|
||||
self.last_called_lock = threading.Lock() # 保护 last_called 的访问
|
||||
|
||||
|
||||
|
||||
def set_sampleRate(self,sampleRate_Code=0x00):
|
||||
'''
|
||||
设置采样率
|
||||
:param sampleRate_Code: 0x00:250Hz,0x01:500Hz,0x02:1000Hz,0x03:2000Hz
|
||||
'''
|
||||
function_code = 0x02
|
||||
gain_code = 0x06
|
||||
sampleRate_Code = [gain_code,sampleRate_Code]
|
||||
packed_data = ProtocolFrame.pack(function_code, sampleRate_Code)
|
||||
if self.method == 'tcp':
|
||||
self.sock.send(packed_data)
|
||||
|
||||
def push_trigger(self,label):
|
||||
'''
|
||||
数据打标
|
||||
@param label:标签类别
|
||||
'''
|
||||
function_code = None
|
||||
label = [label]
|
||||
packed_data = ProtocolFrame.pack(function_code, label)
|
||||
if self.method == 'tcp' and hasattr(self,'serial'):
|
||||
print('发送:', label, datetime.datetime.now().strftime('%H:%M:%S.%f')[:-3])
|
||||
self.serial.write(packed_data)
|
||||
def Impedance(self, On):
|
||||
'''
|
||||
阻抗检测开关
|
||||
:param On:True为开启,False为关闭
|
||||
:return: 组好的协议帧
|
||||
'''
|
||||
function_code = 0x01
|
||||
if On:
|
||||
data = [0x1]
|
||||
self.gain_value = 6
|
||||
else:
|
||||
data = [0x0]
|
||||
self.gain_value = 6
|
||||
packed_data = ProtocolFrame.pack(function_code, data)
|
||||
if self.method == 'tcp':
|
||||
self.sock.send(packed_data)
|
||||
|
||||
def connect(self):
|
||||
try:
|
||||
if self.method == 'serial':
|
||||
# 开启com口,波特率115200,超时5
|
||||
self.sock = serial.Serial(self.host, self.port, timeout=5)
|
||||
self.sock.flushInput() # 清空缓冲区
|
||||
count = self.sock.inWaiting() # 获取串口缓冲区数据
|
||||
while not count:
|
||||
count = self.sock.inWaiting() # 获取串口缓冲区数据
|
||||
# # 接收和存储数据
|
||||
data = (self.sock.read(count))
|
||||
self.receiveData = self.receiveData + data # 将接收数据存储在字符串中
|
||||
elif self.method == 'tcp':
|
||||
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.sock.connect((self.host, int(self.port)))
|
||||
self.set_sampleRate(0x00) #设置250Hz采样率
|
||||
except Exception as e:
|
||||
print("请打开头环")
|
||||
print(e)
|
||||
|
||||
print("connected")
|
||||
|
||||
def extract_packet(self, packet):
|
||||
# 存储一个点的八通道数据
|
||||
dataList = []
|
||||
# 存储116个点的八通道数据
|
||||
dataMatrix = []
|
||||
|
||||
for j in range(5):
|
||||
for i in range(self.n_chan):
|
||||
if not self.toUv:#原始数据直接输出
|
||||
val = (packet[194 * j + 25 + i * 3] << 16) | (packet[194 * j + 25 + 1 + i * 3] << 8) | packet[
|
||||
194 * j + 25 + 2 + i * 3]
|
||||
|
||||
else:#转为uV
|
||||
val = (packet[194 * j + 25 + i * 3] << 16) | (packet[194 * j + 25 + 1 + i * 3] << 8) | packet[
|
||||
194 * j + 25 + 2 + i * 3]
|
||||
if val < 8388608:
|
||||
val = val * 4.5 / self.gain_value / 8388608 * 1000000;
|
||||
else:
|
||||
val = (val - 16777216) * 4.5 / self.gain_value / 8388608 * 1000000;
|
||||
dataList.append(val)
|
||||
#同步触发源
|
||||
val = packet[194 * j + 25 + (i+1) * 3]
|
||||
dataList.append(val)
|
||||
#同步触发序号
|
||||
val = packet[194 * j + 25 + (i+1) * 3+1]
|
||||
dataList.append(val)
|
||||
|
||||
|
||||
# 将数据矩阵进行拼接
|
||||
if len(dataMatrix) == 0:
|
||||
dataMatrix = np.asmatrix(dataList)
|
||||
else:
|
||||
dataMatrix = np.concatenate((dataMatrix, np.asmatrix(dataList)), axis=0)
|
||||
dataList.clear()
|
||||
return np.transpose(dataMatrix)
|
||||
|
||||
def run(self):
|
||||
self.connect()
|
||||
self.running = True
|
||||
self.PackageLength = 998
|
||||
# 启动心跳检测线程
|
||||
threading.Thread(target=self.heartbeat_checker, daemon=True).start()
|
||||
while self.running:
|
||||
try:
|
||||
if self.method == 'serial':
|
||||
count = self.sock.inWaiting() # 获取串口缓冲区数据
|
||||
if count:
|
||||
# 接收和存储数据
|
||||
data = (self.sock.read(count))
|
||||
self.receiveData = self.receiveData + data # 将接收数据存储在字符串中
|
||||
elif self.method == 'tcp':
|
||||
data = self.sock.recv(600)
|
||||
if not data:
|
||||
break
|
||||
self.receiveData += data
|
||||
with self.last_called_lock:
|
||||
self.last_called = time.time()
|
||||
self.status_code = 1 # 收到数据,标记为正常
|
||||
if len(self.receiveData) >= self.PackageLength and self.receiveData.rfind(
|
||||
b'\x55\x55') >= self.PackageLength - 2:
|
||||
|
||||
index = self.receiveData.index(b'\xaa')
|
||||
self.receiveData = self.receiveData[index:]
|
||||
if len(self.receiveData) >= self.PackageLength:
|
||||
onepackage = self.receiveData[:self.PackageLength]
|
||||
if onepackage[7] != 0:
|
||||
self.energy = onepackage[7] # 电量
|
||||
self.receiveData = self.receiveData[self.PackageLength:]
|
||||
dataMatrix = self.extract_packet(onepackage)
|
||||
try:
|
||||
with self.RingBufferLock:
|
||||
self.__ringBuffer.appendBuffer(dataMatrix)
|
||||
except Exception as e:
|
||||
print("锁:写入异常",e)
|
||||
# self.RingBufferLock.release()
|
||||
except ConnectionResetError:
|
||||
self.status_code = 0 # 状态异常
|
||||
print("Connection was reset by the peer.")
|
||||
break
|
||||
self.sock.close()
|
||||
|
||||
# --- 新增:心跳检测线程 ---
|
||||
def heartbeat_checker(self):
|
||||
"""
|
||||
定期检查是否在最近2秒内收到 eegData
|
||||
如果超过2秒未收到,则设置 status_code = 0
|
||||
"""
|
||||
while self.running:
|
||||
time.sleep(0.5) # 每0.5秒检查一次
|
||||
with self.last_called_lock:
|
||||
now = time.time()
|
||||
# 只有收到过一次数据后才开始判断超时
|
||||
if self.last_called > 0 and (now - self.last_called) > 2:
|
||||
if self.status_code != 0:
|
||||
print("EEG data timeout: disconnected")
|
||||
self.status_code = 0
|
||||
def getImpedance(self, data,n_chan):
|
||||
'''
|
||||
获取阻抗值,已经放大100倍,单位是kΩ
|
||||
@param data: 准备计算的通道数据,每通道200个值,注意不要把信号打标的通道传进来
|
||||
@return: 返回各个通道的阻抗值
|
||||
'''
|
||||
impedanceList = []
|
||||
data = data[:n_chan]
|
||||
for channelindex in range(data.shape[0]):
|
||||
if len(data[channelindex]) > 0:
|
||||
data_list = []
|
||||
# 设计陷波滤波器,去除50Hz成分
|
||||
is50filter = True
|
||||
if is50filter:
|
||||
b, a = signal.iirnotch(50, 30, self.srate) # 30是带宽,1000是采样频率
|
||||
data_list = signal.lfilter(b, a, data[channelindex].tolist())
|
||||
|
||||
else:
|
||||
data_list.extend(data[channelindex].tolist())
|
||||
|
||||
data_list = data_list[-1000:]
|
||||
# 执行FFT
|
||||
fft_result = np.fft.fft(data_list)
|
||||
fft_magnitude = np.abs(fft_result / len(data_list)) # 归一化FFT结果
|
||||
freqs = np.fft.fftfreq(len(data_list), d=1 / self.srate) # 频率轴
|
||||
|
||||
# y_amp_modified = np.concatenate(([fft_magnitude[0] / len(t[0].tolist())],
|
||||
# fft_magnitude[1:-1] * 2 / len(t[0].tolist()),
|
||||
# [fft_magnitude[-1] / len(t[0].tolist())]))
|
||||
|
||||
# 找到幅值最大的频率成分的索引(忽略直流分量,即索引0)
|
||||
max_index = np.argmax(fft_magnitude[1:])
|
||||
|
||||
# 获取最大幅值的频率索引(加上1,因为索引0是直流分量)
|
||||
freq_index = max_index + 1
|
||||
|
||||
# 获取最大幅值
|
||||
max_magnitude = fft_magnitude[freq_index]
|
||||
|
||||
# 阻抗
|
||||
import math
|
||||
result = math.sqrt(2) * math.pi * max_magnitude / 6 / 4
|
||||
result *= 0.44 * 100 # 统一放大100倍
|
||||
impedanceList.append(int(result))
|
||||
# print(max_magnitude, result)
|
||||
else:
|
||||
impedanceList.append(0)
|
||||
impedances = np.array(impedanceList)
|
||||
return impedances
|
||||
def getData(self,count):
|
||||
'''
|
||||
获取最新的数据
|
||||
@param count: 每通道返回的最数值数目
|
||||
@return: 所有通道的最新count个数值
|
||||
'''
|
||||
data=None
|
||||
try:
|
||||
with self.RingBufferLock:
|
||||
data = self.__ringBuffer.getData(count)
|
||||
except:
|
||||
print("锁:读取异常")
|
||||
# self.RingBufferLock.release()
|
||||
|
||||
|
||||
return data
|
||||
def GetDataLenCount(self):
|
||||
'''
|
||||
获取最新缓存中每个通道的数量
|
||||
@return:
|
||||
'''
|
||||
return self.__ringBuffer.nUpdate
|
||||
|
||||
def ResetAll(self):
|
||||
'''
|
||||
清空缓存
|
||||
@return:
|
||||
'''
|
||||
with self.RingBufferLock:
|
||||
self.__ringBuffer.resetAllPara()
|
||||
def stop(self):
|
||||
self.running = False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Usage
|
||||
Linker = SunnyLinker64('127.0.0.1', 5086, 1000, 65)
|
||||
Linker.start()
|
||||
|
||||
try:
|
||||
while True:
|
||||
time.sleep(0.005)
|
||||
if(Linker.count()>0):
|
||||
# print(Linker.ringBuffer.nUpdate)
|
||||
t = Linker.getData()
|
||||
print(t.shape[1], Linker.count())
|
||||
# Linker.ringBuffer.nUpdate=0
|
||||
# time.sleep(0.2)
|
||||
except KeyboardInterrupt:
|
||||
Linker.stop()
|
||||
|
||||
|
||||
|
||||
113
algorithm_V0/datacollect/build_algorithm.spec
Normal file
113
algorithm_V0/datacollect/build_algorithm.spec
Normal file
@@ -0,0 +1,113 @@
|
||||
# -*- mode: python ; coding: utf-8 -*-
|
||||
|
||||
import sys
|
||||
import os
|
||||
from PyInstaller.utils.hooks import collect_submodules, collect_data_files
|
||||
|
||||
# ========================================================
|
||||
# 1. 工程配置区 (Project Config)
|
||||
# ========================================================
|
||||
block_cipher = None
|
||||
ENTRY_POINT = 'start_parse.py'
|
||||
APP_NAME = 'start_parse' # 打包后生成的文件夹名和 exe 名
|
||||
|
||||
# ========================================================
|
||||
# 2. 依赖分析 (Dependency Analysis)
|
||||
# ========================================================
|
||||
hidden_imports = [
|
||||
# eegParser 依赖
|
||||
'numpy',
|
||||
'numpy.lib.stride_tricks',
|
||||
'pandas',
|
||||
'scipy',
|
||||
'scipy.io',
|
||||
'scipy.io.savemat',
|
||||
'scipy.signal',
|
||||
|
||||
# SunnyLinker 依赖
|
||||
'serial',
|
||||
'serial.serialutil',
|
||||
'socket',
|
||||
|
||||
# zmq 通信依赖
|
||||
'zmq',
|
||||
'zmq.asyncio',
|
||||
|
||||
# 其他可能遗漏的模块
|
||||
'threading',
|
||||
'datetime',
|
||||
]
|
||||
|
||||
# 收集 zmq 的所有子模块
|
||||
try:
|
||||
hidden_imports += collect_submodules('zmq')
|
||||
except:
|
||||
pass
|
||||
|
||||
# ========================================================
|
||||
# 3. 资源锚定 (Data Anchoring)
|
||||
# ========================================================
|
||||
# 打包时需要包含的资源文件
|
||||
datas = [
|
||||
('xy_64.xlsx', '.'), # 电极位置文件
|
||||
]
|
||||
|
||||
# 收集 mne 的数据文件(如果有)
|
||||
try:
|
||||
datas += collect_data_files('mne')
|
||||
except:
|
||||
pass
|
||||
|
||||
# ========================================================
|
||||
# 4. 构建流程 (Build Process)
|
||||
# ========================================================
|
||||
a = Analysis(
|
||||
[ENTRY_POINT],
|
||||
pathex=[],
|
||||
binaries=[],
|
||||
datas=datas,
|
||||
hiddenimports=hidden_imports,
|
||||
hookspath=[],
|
||||
hooksconfig={},
|
||||
runtime_hooks=[],
|
||||
excludes=['tkinter', 'PyQt5', 'PySide2', 'IPython', 'matplotlib'],
|
||||
win_no_prefer_redirects=False,
|
||||
win_private_assemblies=False,
|
||||
cipher=block_cipher,
|
||||
noarchive=False,
|
||||
)
|
||||
|
||||
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
|
||||
|
||||
exe = EXE(
|
||||
pyz,
|
||||
a.scripts,
|
||||
[],
|
||||
exclude_binaries=True,
|
||||
name=APP_NAME,
|
||||
debug=False,
|
||||
bootloader_ignore_signals=False,
|
||||
strip=False,
|
||||
upx=True,
|
||||
upx_exclude=[],
|
||||
console=True,
|
||||
disable_windowed_traceback=False,
|
||||
argv_emulation=False,
|
||||
target_arch=None,
|
||||
codesign_identity=None,
|
||||
entitlements_file=None,
|
||||
)
|
||||
|
||||
# ========================================================
|
||||
# 5. 打包模式: OneDir (单文件夹)
|
||||
# ========================================================
|
||||
coll = COLLECT(
|
||||
exe,
|
||||
a.binaries,
|
||||
a.zipfiles,
|
||||
a.datas,
|
||||
strip=False,
|
||||
upx=True,
|
||||
upx_exclude=[],
|
||||
name=APP_NAME,
|
||||
)
|
||||
76
algorithm_V0/datacollect/build_datacollect.py
Normal file
76
algorithm_V0/datacollect/build_datacollect.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
打包脚本 - datacollect
|
||||
用于将 EEG 数据采集程序打包为独立的 exe 文件
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import shutil
|
||||
import subprocess
|
||||
|
||||
def main():
|
||||
# 1. 定义路径
|
||||
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
DIST_DIR = os.path.join(BASE_DIR, 'dist')
|
||||
BUILD_DIR = os.path.join(BASE_DIR, 'build')
|
||||
APP_NAME = 'start_parse'
|
||||
|
||||
# 2. 清理旧构建
|
||||
print("[1/3] Cleaning up old builds...")
|
||||
for dir_path in [DIST_DIR, BUILD_DIR]:
|
||||
if os.path.exists(dir_path):
|
||||
try:
|
||||
shutil.rmtree(dir_path)
|
||||
print(f" Cleaned {os.path.basename(dir_path)}/")
|
||||
except Exception as e:
|
||||
print(f" Warning: Could not clean {dir_path}: {e}")
|
||||
|
||||
# 3. 检查必要文件
|
||||
print("\n[2/3] Checking required files...")
|
||||
required_files = ['start_parse.py', 'eegParser.py', 'build_algorithm.spec', 'xy_64.xlsx']
|
||||
for f in required_files:
|
||||
path = os.path.join(BASE_DIR, f)
|
||||
if os.path.exists(path):
|
||||
print(f" ✓ {f}")
|
||||
else:
|
||||
print(f" ✗ {f} NOT FOUND!")
|
||||
sys.exit(1)
|
||||
|
||||
# 4. 运行 PyInstaller
|
||||
print("\n[3/3] Running PyInstaller...")
|
||||
spec_file = os.path.join(BASE_DIR, 'build_algorithm.spec')
|
||||
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"-m", "PyInstaller",
|
||||
spec_file,
|
||||
"--clean",
|
||||
"--noconfirm"
|
||||
]
|
||||
|
||||
try:
|
||||
subprocess.check_call(cmd, cwd=BASE_DIR)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"\n✗ PyInstaller failed with error code: {e.returncode}")
|
||||
sys.exit(1)
|
||||
|
||||
# 5. 验证结果
|
||||
exe_path = os.path.join(DIST_DIR, APP_NAME, f'{APP_NAME}.exe')
|
||||
if os.path.exists(exe_path):
|
||||
size_mb = os.path.getsize(exe_path) / (1024 * 1024)
|
||||
print(f"\n{'='*50}")
|
||||
print(f"✓ SUCCESS! Executable created:")
|
||||
print(f" {exe_path}")
|
||||
print(f" Size: {size_mb:.1f} MB")
|
||||
print(f"{'='*50}")
|
||||
print(f"\n部署说明:")
|
||||
print(f" 1. 复制 dist/start_parse 文件夹到目标电脑")
|
||||
print(f" 2. 确保目标电脑已安装 EEG 设备的 USB 驱动")
|
||||
print(f" 3. 运行 start_parse.exe")
|
||||
else:
|
||||
print("\n✗ Build failed - executable not found")
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
72
algorithm_V0/datacollect/build_with_copy.py
Normal file
72
algorithm_V0/datacollect/build_with_copy.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
def main():
|
||||
# 1. 定义路径
|
||||
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
DIST_DIR = os.path.join(BASE_DIR, 'dist')
|
||||
APP_NAME = 'Depression_Decoder'
|
||||
|
||||
MODEL_SRC = os.path.join(BASE_DIR, 'model')
|
||||
RAW_DATA_SRC = os.path.join(BASE_DIR, 'raw_data')
|
||||
|
||||
# 2. 清理旧构建
|
||||
print("[1/3] Cleaning up old builds...")
|
||||
if os.path.exists(DIST_DIR):
|
||||
try:
|
||||
shutil.rmtree(DIST_DIR)
|
||||
print(" Cleaned dist/")
|
||||
except Exception as e:
|
||||
print(f" Warning: Could not clean dist/: {e}")
|
||||
|
||||
BUILD_DIR = os.path.join(BASE_DIR, 'build')
|
||||
if os.path.exists(BUILD_DIR):
|
||||
try:
|
||||
shutil.rmtree(BUILD_DIR)
|
||||
print(" Cleaned build/")
|
||||
except Exception as e:
|
||||
print(f" Warning: Could not clean build/: {e}")
|
||||
|
||||
# 3. 运行 PyInstaller
|
||||
print("[2/3] Running PyInstaller...")
|
||||
cmd = [
|
||||
"pyinstaller",
|
||||
"build_algorithm.spec",
|
||||
"--clean",
|
||||
"--noconfirm"
|
||||
]
|
||||
|
||||
try:
|
||||
subprocess.check_call(cmd, shell=True)
|
||||
except subprocess.CalledProcessError:
|
||||
print("Error: PyInstaller failed.")
|
||||
sys.exit(1)
|
||||
|
||||
# 4. 复制外部资源 (如果存在)
|
||||
print("[3/3] Copying external resources...")
|
||||
|
||||
# 确保 dist 目录存在 (pyinstaller 应该已经创建了)
|
||||
if not os.path.exists(DIST_DIR):
|
||||
os.makedirs(DIST_DIR)
|
||||
|
||||
for src_path, folder_name in [(MODEL_SRC, 'model'), (RAW_DATA_SRC, 'raw_data')]:
|
||||
dst_path = os.path.join(DIST_DIR, folder_name)
|
||||
if os.path.exists(src_path):
|
||||
try:
|
||||
if os.path.exists(dst_path):
|
||||
shutil.rmtree(dst_path)
|
||||
shutil.copytree(src_path, dst_path)
|
||||
print(f" Copied {folder_name} to dist/")
|
||||
except Exception as e:
|
||||
print(f" Error copying {folder_name}: {e}")
|
||||
else:
|
||||
print(f" Note: {folder_name} source not found at {src_path}, skipping.")
|
||||
|
||||
print("\n" + "="*50)
|
||||
print(f"SUCCESS! Executable is in: {DIST_DIR}")
|
||||
print("="*50)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
427
algorithm_V0/datacollect/eegParser.py
Normal file
427
algorithm_V0/datacollect/eegParser.py
Normal file
@@ -0,0 +1,427 @@
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from SunnyLinker import SunnyLinker64
|
||||
from zmqServer import zmqServer
|
||||
from zmqClient import zmqClient
|
||||
from scipy.io import savemat
|
||||
from scipy import signal
|
||||
|
||||
class Parser_main(threading.Thread):
|
||||
def __init__(self):
|
||||
threading.Thread.__init__(self)
|
||||
self.Running = True
|
||||
self.fs = 250 # 采样率
|
||||
self.energy = 0 # 电量
|
||||
self.status_code = 0 # 与采集设备通信的状态码,0为异常,1为正常
|
||||
self.n_chan = 64
|
||||
self.dataBuffer = []
|
||||
self.file_num = 0#保存文件序号
|
||||
self.subject_id = None #受试者ID
|
||||
self.session_id = None #Session ID
|
||||
self.last_print_time = None
|
||||
|
||||
# 预处理参数
|
||||
self.enable_preprocess = True # 是否启用预处理
|
||||
self.lowcut = 0.5 # 高通滤波截止频率 (Hz)
|
||||
self.highcut = 50 # 低通滤波截止频率 (Hz)
|
||||
self.notch_freq = 50 # 工频陷波频率 (Hz)
|
||||
self.ref_chan_name = 'CPZ' # 参考电极名称
|
||||
self.ref_chan_idx = None # 参考电极索引(运行时确定)
|
||||
self._init_filter_cache() # 初始化滤波器缓存
|
||||
|
||||
# 单位转换参数
|
||||
self.calibration_scale = 1.0 # 校准系数,用于修正单位转换误差
|
||||
self.calibration_offset = 0 # 校准偏移量
|
||||
self._conversion_verified = False # 是否已验证转换
|
||||
|
||||
def connect(self):
|
||||
self.thread_data_server = SunnyLinker64('127.0.0.1', 7878, 250, 64,
|
||||
method='tcp')
|
||||
self.thread_data_server.toUv = True
|
||||
self.thread_data_server.start()
|
||||
|
||||
self.zmqServer = zmqServer()
|
||||
self.zmqServer.start()
|
||||
self.zmqClient = zmqClient('127.0.0.1', 8088)
|
||||
self.zmqClient.connect()
|
||||
|
||||
def run(self):
|
||||
while self.Running:
|
||||
# 同步信息
|
||||
if self.zmqServer.state_mode == 'sync':
|
||||
self.zmqClient.send_to_all('sync', self.zmqClient.state)
|
||||
self.zmqServer.state_mode = 'rest'
|
||||
# 状态异常,报告上位机
|
||||
if self.status_code != self.thread_data_server.status_code:
|
||||
self.status_code = self.thread_data_server.status_code
|
||||
self.zmqClient.send_to_all('status_code', int(self.status_code))
|
||||
|
||||
# 返回电量
|
||||
if self.energy != self.thread_data_server.energy:
|
||||
self.energy = self.thread_data_server.energy
|
||||
self.zmqClient.send_to_all('energy', int(self.energy))
|
||||
|
||||
# 更新文件序号
|
||||
if self.subject_id != self.zmqServer.subject_id or self.session_id != self.zmqServer.session_id:
|
||||
self.subject_id = self.zmqServer.subject_id
|
||||
self.session_id = self.zmqServer.session_id
|
||||
self.file_num = 0 #从零开始计数
|
||||
|
||||
if self.zmqServer.open_Impedance == True: # 开启阻抗检测功能,仅运行一次
|
||||
self.thread_data_server.Impedance(True)
|
||||
self.zmqServer.open_Impedance = -1
|
||||
elif self.zmqServer.open_Impedance == False:
|
||||
self.thread_data_server.Impedance(False)
|
||||
self.zmqServer.open_Impedance = -1
|
||||
|
||||
if self.zmqServer.get_Impedance: # 返回阻抗值
|
||||
if self.thread_data_server.GetDataLenCount() > self.fs:
|
||||
Impe_data = self.thread_data_server.getData(self.fs)
|
||||
# 计算阻抗
|
||||
imps = self.thread_data_server.getImpedance(Impe_data, self.n_chan)
|
||||
self.zmqClient.send_to_all('impedance', imps.tolist())
|
||||
else:
|
||||
pass
|
||||
if self.thread_data_server.GetDataLenCount() < 50:
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||
data = self.thread_data_server.getData(50)
|
||||
data = data[:self.n_chan, :]
|
||||
|
||||
# 数据质量检查与预处理
|
||||
if self.enable_preprocess:
|
||||
# 1. 首先验证和校准单位转换
|
||||
data, calibrated = self.verify_and_calibrate_unit(data)
|
||||
if calibrated:
|
||||
print('[INFO] 单位转换已自动校准')
|
||||
|
||||
# 2. 检查数据质量
|
||||
issues = self.check_data_quality(data)
|
||||
if issues:
|
||||
print('[警告] 检测到数据质量问题:')
|
||||
for issue in issues:
|
||||
print(f' - {issue}')
|
||||
print('[INFO] 正在进行信号预处理...')
|
||||
|
||||
# 3. 执行预处理
|
||||
data = self.preprocess_data(data)
|
||||
|
||||
# 4. 预处理后验证
|
||||
if issues:
|
||||
new_issues = self.check_data_quality(data)
|
||||
if not new_issues:
|
||||
print(f'[INFO] 预处理完成,数据幅度正常: {np.max(np.abs(data)):.2f} µV')
|
||||
else:
|
||||
print('[警告] 预处理后仍存在问题:')
|
||||
for issue in new_issues:
|
||||
print(f' - {issue}')
|
||||
|
||||
if self.zmqServer.mat_generate:
|
||||
# 检测是否需要重置缓冲区(第二次发送 matGenerate 时清空旧数据)
|
||||
if self.zmqServer.reset_mat_buffer:
|
||||
self.dataBuffer = []
|
||||
self.last_print_time = None
|
||||
self.zmqServer.reset_mat_buffer = False
|
||||
print('[INFO] 数据缓冲区已重置,从头开始采集')
|
||||
|
||||
self.dataBuffer.append(data)
|
||||
if len(self.dataBuffer) % 50 == 0:
|
||||
current_time = time.time()
|
||||
if self.last_print_time is not None:
|
||||
elapsed_time = current_time - self.last_print_time
|
||||
# 2500个点 = 50个数据块 * 50个采样点/数据块
|
||||
actual_fs = 2500 / elapsed_time
|
||||
print(f"接收 2500 个采样点耗时: {elapsed_time:.4f} 秒, 折合实际采样率: {actual_fs:.2f} Hz")
|
||||
else:
|
||||
print("开始计时...")
|
||||
self.last_print_time = current_time
|
||||
print('数据保存进度: {}/{}'.format(len(self.dataBuffer),int(self.zmqServer.save_win*self.fs//50)))
|
||||
if len(self.dataBuffer) >= int(self.zmqServer.save_win*self.fs//50): #5分钟*60秒*250Hz / 50
|
||||
self.zmqServer.mat_generate = False
|
||||
matData = np.hstack(self.dataBuffer[:int(self.zmqServer.save_win*self.fs//50)])
|
||||
self.dataBuffer = []
|
||||
self.last_print_time = None # 重置计时器以备下次使用
|
||||
self.pack2mat(matData,self.subject_id,self.session_id)
|
||||
|
||||
def pack2mat(self,data,subject_id,session_id):
|
||||
#EEG数据
|
||||
Data = data.T
|
||||
#通道名称
|
||||
channel_names = np.array(
|
||||
['AIN1', 'AIN2', 'AIN3', 'AIN4', 'AIN5', 'AIN6', 'AIN7', 'AIN8', 'AIN9', 'AIN10', 'AIN11', 'AIN12',
|
||||
'AIN13', 'AIN14', 'AIN15', 'AIN16', 'AIN17', 'AIN18', 'AIN19', 'AIN20', 'AIN21', 'AIN22', 'AIN23',
|
||||
'AIN24', 'AIN25', 'AIN26', 'AIN27', 'AIN28', 'AIN29', 'AIN30', 'AIN31', 'AIN32', 'AIN33', 'AIN34',
|
||||
'AIN35', 'AIN36', 'AIN37', 'AIN38', 'AIN39', 'AIN40', 'AIN41', 'AIN42', 'AIN43', 'AIN44', 'AIN45',
|
||||
'AIN46', 'AIN47', 'AIN48', 'AIN49', 'AIN50', 'AIN51', 'AIN52', 'AIN53', 'AIN54', 'AIN55', 'AIN56',
|
||||
'AIN57', 'AIN58', 'AIN59', 'AIN60', 'AIN61', 'AIN62', 'AIN63', 'AIN64'], dtype=object)
|
||||
#采样率
|
||||
sample_rate = self.fs
|
||||
#通道数量
|
||||
node_number = Data.shape[1]
|
||||
# 时间轴
|
||||
t = np.linspace(0, self.zmqServer.save_win, Data.shape[0])
|
||||
t = t.reshape(len(t), 1)
|
||||
#电极名称
|
||||
electrode_name = np.array(['FP1', 'FP2', 'PO6', 'POZ', 'F3', 'F4', 'FPZ', 'AF4', 'FC3', 'PO8', 'CP2', 'CP1',
|
||||
'FCZ', 'PO5', 'FC2', 'FC1', 'C3', 'C4', 'FC4', 'CP4', 'P3', 'P4', 'F5', 'C5', 'F6',
|
||||
'PO4', 'CP6', 'CP5', 'PO3', 'CP3', 'FC6', 'FC5', 'CB1', 'CB2', 'P5', 'AF7', 'A1','T7',
|
||||
'FT7', 'TP7', 'FT8', 'AF8', 'F8', 'F7', 'P6', 'C6', 'O2', 'O1', 'T8', 'P7', 'CZ','PZ',
|
||||
'P8', 'FZ', 'OZ', 'PO7', 'TP8', 'AF3', 'C2', 'C1', 'P2', 'P1', 'F2', 'F1'],
|
||||
dtype=object)
|
||||
#电极三维坐标
|
||||
electrode_xyz = self.read_ch_pos()
|
||||
electrode_xyz.update({'A1': [-0.095, 0, -0.005]})
|
||||
electrode_xyz = {key: electrode_xyz[key] for key in electrode_name}
|
||||
electrode_xyz = np.array(list(electrode_xyz.values()))
|
||||
#电极坐标所属的坐标系
|
||||
electrode_coord_system = '10-20 spherical model'
|
||||
#受试者ID
|
||||
Subject_id = subject_id
|
||||
#Session ID
|
||||
Session_id = session_id
|
||||
#参考电极方案
|
||||
ref = 'CPZ'
|
||||
#数据采集开始时间
|
||||
start_time = 0
|
||||
|
||||
meta_struct = {
|
||||
'subject_id': Subject_id,
|
||||
'session_id': Session_id,
|
||||
'ref': ref,
|
||||
'start_time': start_time
|
||||
}
|
||||
|
||||
eeg_struct = {
|
||||
'data': Data,
|
||||
'chn': channel_names,
|
||||
'sample_rate': sample_rate,
|
||||
'node_number': node_number,
|
||||
't': t,
|
||||
'electrode_name': electrode_name,
|
||||
'electrode_xyz': electrode_xyz,
|
||||
'electrode_coord_system': electrode_coord_system,
|
||||
'meta': meta_struct,
|
||||
}
|
||||
|
||||
fileDir = os.path.join('EEGfiles/',Subject_id,Session_id)
|
||||
os.makedirs(fileDir,exist_ok=True)
|
||||
filePath = os.path.join(fileDir,'eeg_data{}.mat'.format(self.file_num))
|
||||
# 保存到 .mat 文件,顶层变量名为 'eeg'
|
||||
savemat(filePath, {'eeg': eeg_struct})
|
||||
print('EEGfile saved at {}'.format(filePath))
|
||||
self.zmqClient.send_to_all('filePath', filePath)
|
||||
self.file_num += 1
|
||||
|
||||
def read_ch_pos(self,file_path=r'xy_64.xlsx'):
|
||||
"""
|
||||
将电极位置信息转换为Dict
|
||||
|
||||
参数:
|
||||
file_path: 电极位置存储文件, 必须包含'channel', 'x', 'y', 'z'列
|
||||
|
||||
"""
|
||||
if getattr(sys, 'frozen', False):
|
||||
script_dir = sys._MEIPASS
|
||||
else:
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
file_path = os.path.join(script_dir, file_path)
|
||||
df = pd.read_excel(file_path)
|
||||
# 确保列名正确
|
||||
if not all(col in df.columns for col in ['channel', 'x', 'y', 'z']):
|
||||
raise ValueError("DataFrame必须包含'channel', 'x', 'y', 'z'列")
|
||||
# 创建电极位置字典
|
||||
ch_pos = {}
|
||||
for _, row in df.iterrows():
|
||||
ch_pos[row['channel']] = [row['x'], row['y'], row['z']]
|
||||
return ch_pos
|
||||
|
||||
def _init_filter_cache(self):
|
||||
"""初始化滤波器系数缓存"""
|
||||
self._filter_cache = {
|
||||
'highpass': None,
|
||||
'lowpass': None,
|
||||
'notch': None
|
||||
}
|
||||
self._cache_valid = False
|
||||
|
||||
def _design_filters(self):
|
||||
"""设计滤波器系数"""
|
||||
if self._cache_valid:
|
||||
return
|
||||
|
||||
nyquist = self.fs / 2
|
||||
fs_nyq = self.fs
|
||||
|
||||
# 高通滤波 (去除低频漂移)
|
||||
high = self.lowcut / nyquist
|
||||
if 0 < high < 1:
|
||||
self._filter_cache['highpass'] = signal.butter(2, high, btype='high', output='ba')
|
||||
|
||||
# 低通滤波 (去除高频噪声)
|
||||
low = self.highcut / nyquist
|
||||
if 0 < low < 1:
|
||||
self._filter_cache['lowpass'] = signal.butter(4, low, btype='low', output='ba')
|
||||
|
||||
# 50Hz 陷波滤波 (去除工频干扰)
|
||||
Q = 30 # 品质因子
|
||||
self._filter_cache['notch'] = signal.iirnotch(self.notch_freq, Q, fs=fs_nyq)
|
||||
|
||||
# 查找CPZ通道索引
|
||||
electrode_name = ['FP1', 'FP2', 'PO6', 'POZ', 'F3', 'F4', 'FPZ', 'AF4', 'FC3', 'PO8', 'CP2', 'CP1',
|
||||
'FCZ', 'PO5', 'FC2', 'FC1', 'C3', 'C4', 'FC4', 'CP4', 'P3', 'P4', 'F5', 'C5', 'F6',
|
||||
'PO4', 'CP6', 'CP5', 'PO3', 'CP3', 'FC6', 'FC5', 'CB1', 'CB2', 'P5', 'AF7', 'A1','T7',
|
||||
'FT7', 'TP7', 'FT8', 'AF8', 'F8', 'F7', 'P6', 'C6', 'O2', 'O1', 'T8', 'P7', 'CZ','PZ',
|
||||
'P8', 'FZ', 'OZ', 'PO7', 'TP8', 'AF3', 'C2', 'C1', 'P2', 'P1', 'F2', 'F1']
|
||||
try:
|
||||
self.ref_chan_idx = electrode_name.index(self.ref_chan_name)
|
||||
except ValueError:
|
||||
self.ref_chan_idx = 50 # 默认CZ (对应索引50)
|
||||
print(f'[警告] 未找到参考电极 {self.ref_chan_name},使用默认值 CZ')
|
||||
|
||||
self._cache_valid = True
|
||||
print(f'[INFO] 预处理已启用 - 高通:{self.lowcut}Hz, 低通:{self.highcut}Hz, 陷波:{self.notch_freq}Hz, 参考:{self.ref_chan_name}(索引:{self.ref_chan_idx})')
|
||||
|
||||
def check_data_quality(self, data):
|
||||
"""
|
||||
检查数据质量
|
||||
|
||||
返回:
|
||||
list: 发现的问题列表,空列表表示质量正常
|
||||
"""
|
||||
issues = []
|
||||
|
||||
# 检查幅度
|
||||
amplitude = np.max(np.abs(data))
|
||||
if amplitude > 1e6: # 超过 1mV = 1000µV
|
||||
issues.append(f'幅度异常: {amplitude:.2e} (可能为原始ADC值或单位错误)')
|
||||
elif amplitude > 1000:
|
||||
issues.append(f'幅度偏高: {amplitude:.2f}')
|
||||
|
||||
# 检查平坦噪声 (通道可能未连接)
|
||||
if np.std(data) < 0.01:
|
||||
issues.append('信号过平,可能通道未连接')
|
||||
|
||||
# 检查饱和
|
||||
n_saturated = np.sum(np.abs(data) > 1e8)
|
||||
if n_saturated > 0:
|
||||
issues.append(f'检测到 {n_saturated} 个采样点饱和')
|
||||
|
||||
return issues
|
||||
|
||||
def verify_and_calibrate_unit(self, data):
|
||||
"""
|
||||
验证并校准数据单位
|
||||
|
||||
SunnyLinker64 的转换公式:
|
||||
val = raw_adc * 4.5 / gain_value / 8388608 * 1000000 (µV)
|
||||
|
||||
但如果硬件实际增益与 gain_value=6 不符,会导致单位错误。
|
||||
本函数通过检测数据范围来验证和修正单位。
|
||||
|
||||
正常EEG信号范围: ±50-100 µV
|
||||
如果检测到的数据范围是 ±1e6 量级,说明转换可能有问题
|
||||
|
||||
参数:
|
||||
data: 原始数据
|
||||
|
||||
返回:
|
||||
tuple: (校准后的数据, 是否进行了校准)
|
||||
"""
|
||||
if self._conversion_verified:
|
||||
return data, False
|
||||
|
||||
amplitude = np.max(np.abs(data))
|
||||
|
||||
# 判断数据是否在合理范围内
|
||||
# 正常EEG: 1 - 1000 µV (考虑某些高幅值情况)
|
||||
# 异常: > 1e5 µV (可能是ADC原始值未转换或转换系数错误)
|
||||
|
||||
if amplitude > 1e6:
|
||||
print('[警告] 检测到异常大幅值数据,可能是ADC原始值或单位转换失败!')
|
||||
print(f' 当前最大幅度: {amplitude:.2e} µV')
|
||||
print('[INFO] 尝试自动校准单位转换...')
|
||||
|
||||
# SunnyLinker64 的理论转换系数约为 0.0894 µV/LSB
|
||||
# 如果数据是原始ADC值,需要除以这个系数来还原
|
||||
theoretical_scale = 4.5 / 6 / 8388608 * 1e6 # 理论系数: ~0.0894 µV/LSB
|
||||
|
||||
# 计算校准系数
|
||||
# 假设数据是原始ADC值,需要除以 (amplitude / expected_amplitude)
|
||||
# 正常EEG信号预期幅度约 100 µV
|
||||
expected_amplitude = 100.0 # µV
|
||||
|
||||
if amplitude > expected_amplitude:
|
||||
# 计算校准系数: 原始值 / 预期值 = 实际值 / 校准后值
|
||||
self.calibration_scale = expected_amplitude / amplitude
|
||||
|
||||
# 应用校准
|
||||
data = data * self.calibration_scale
|
||||
print(f'[INFO] 校准完成,应用系数: {self.calibration_scale:.6e}')
|
||||
print(f' 校准后最大幅度: {np.max(np.abs(data)):.2f} µV')
|
||||
self._conversion_verified = True
|
||||
return data, True
|
||||
|
||||
elif amplitude < 0.01:
|
||||
print('[警告] 数据幅度接近零,可能通道未连接或设备异常')
|
||||
|
||||
self._conversion_verified = True
|
||||
return data, False
|
||||
|
||||
def preprocess_data(self, data):
|
||||
"""
|
||||
EEG信号预处理
|
||||
|
||||
参数:
|
||||
data: ndarray, shape (n_chan, n_samples), 原始EEG数据
|
||||
|
||||
返回:
|
||||
ndarray: 预处理后的EEG数据
|
||||
"""
|
||||
if not self.enable_preprocess:
|
||||
return data
|
||||
|
||||
# 确保数据是 float64 类型
|
||||
data = data.astype(np.float64)
|
||||
|
||||
# 设计滤波器
|
||||
self._design_filters()
|
||||
|
||||
# 1. 去除直流分量和低频漂移 (高通滤波)
|
||||
if self._filter_cache['highpass'] is not None:
|
||||
b, a = self._filter_cache['highpass']
|
||||
for ch in range(data.shape[0]):
|
||||
data[ch, :] = signal.filtfilt(b, a, data[ch, :])
|
||||
|
||||
# 2. 50Hz 工频陷波滤波
|
||||
if self._filter_cache['notch'] is not None:
|
||||
b, a = self._filter_cache['notch']
|
||||
for ch in range(data.shape[0]):
|
||||
data[ch, :] = signal.filtfilt(b, a, data[ch, :])
|
||||
|
||||
# 3. 低通滤波 (去除高频噪声)
|
||||
if self._filter_cache['lowpass'] is not None:
|
||||
b, a = self._filter_cache['lowpass']
|
||||
for ch in range(data.shape[0]):
|
||||
data[ch, :] = signal.filtfilt(b, a, data[ch, :])
|
||||
|
||||
# 4. 重参考 (以CPZ为参考)
|
||||
if self.ref_chan_idx is not None and self.ref_chan_idx < data.shape[0]:
|
||||
ref_signal = data[self.ref_chan_idx, :]
|
||||
data = data - ref_signal
|
||||
|
||||
return data
|
||||
|
||||
def stop(self):
|
||||
'''
|
||||
停止运行
|
||||
@return:
|
||||
'''
|
||||
self.zmqServer.stop()
|
||||
self.Running=False
|
||||
207
algorithm_V0/datacollect/eegParser_scipy_package.py
Normal file
207
algorithm_V0/datacollect/eegParser_scipy_package.py
Normal file
@@ -0,0 +1,207 @@
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from SunnyLinker import SunnyLinker64
|
||||
from zmqServer import zmqServer
|
||||
from zmqClient import zmqClient
|
||||
from scipy.io import savemat
|
||||
|
||||
class Parser_main(threading.Thread):
|
||||
def __init__(self):
|
||||
threading.Thread.__init__(self)
|
||||
self.Running = True
|
||||
self.fs = 250 # 采样率
|
||||
self.energy = 0 # 电量
|
||||
self.status_code = 0 # 与采集设备通信的状态码,0为异常,1为正常
|
||||
self.n_chan = 64
|
||||
self.dataBuffer = []
|
||||
self.file_num = 0#保存文件序号
|
||||
self.subject_id = None #受试者ID
|
||||
self.session_id = None #Session ID
|
||||
self.last_print_time = None
|
||||
|
||||
def connect(self):
|
||||
self.thread_data_server = SunnyLinker64('127.0.0.1', 7878, 250, 64,
|
||||
method='tcp')
|
||||
self.thread_data_server.toUv = True
|
||||
self.thread_data_server.start()
|
||||
|
||||
self.zmqServer = zmqServer()
|
||||
self.zmqServer.start()
|
||||
self.zmqClient = zmqClient('127.0.0.1', 8088)
|
||||
self.zmqClient.connect()
|
||||
|
||||
def run(self):
|
||||
while self.Running:
|
||||
# 同步信息
|
||||
if self.zmqServer.state_mode == 'sync':
|
||||
self.zmqClient.send_to_all('sync', self.zmqClient.state)
|
||||
self.zmqServer.state_mode = 'rest'
|
||||
# 状态异常,报告上位机
|
||||
if self.status_code != self.thread_data_server.status_code:
|
||||
self.status_code = self.thread_data_server.status_code
|
||||
self.zmqClient.send_to_all('status_code', int(self.status_code))
|
||||
|
||||
# 返回电量
|
||||
if self.energy != self.thread_data_server.energy:
|
||||
self.energy = self.thread_data_server.energy
|
||||
self.zmqClient.send_to_all('energy', int(self.energy))
|
||||
|
||||
# 更新文件序号
|
||||
if self.subject_id != self.zmqServer.subject_id or self.session_id != self.zmqServer.session_id:
|
||||
self.subject_id = self.zmqServer.subject_id
|
||||
self.session_id = self.zmqServer.session_id
|
||||
self.file_num = 0 #从零开始计数
|
||||
|
||||
if self.zmqServer.open_Impedance == True: # 开启阻抗检测功能,仅运行一次
|
||||
self.thread_data_server.Impedance(True)
|
||||
self.zmqServer.open_Impedance = -1
|
||||
elif self.zmqServer.open_Impedance == False:
|
||||
self.thread_data_server.Impedance(False)
|
||||
self.zmqServer.open_Impedance = -1
|
||||
|
||||
if self.zmqServer.get_Impedance: # 返回阻抗值
|
||||
if self.thread_data_server.GetDataLenCount() > self.fs:
|
||||
Impe_data = self.thread_data_server.getData(self.fs)
|
||||
# 计算阻抗
|
||||
imps = self.thread_data_server.getImpedance(Impe_data, self.n_chan)
|
||||
self.zmqClient.send_to_all('impedance', imps.tolist())
|
||||
else:
|
||||
pass
|
||||
if self.thread_data_server.GetDataLenCount() < 50:
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||
data = self.thread_data_server.getData(50)
|
||||
data = data[:self.n_chan, :]
|
||||
if self.zmqServer.mat_generate:
|
||||
# 检测是否需要重置缓冲区(第二次发送 matGenerate 时清空旧数据)
|
||||
if self.zmqServer.reset_mat_buffer:
|
||||
self.dataBuffer = []
|
||||
self.last_print_time = None
|
||||
self.zmqServer.reset_mat_buffer = False
|
||||
print('[INFO] 数据缓冲区已重置,从头开始采集')
|
||||
|
||||
self.dataBuffer.append(data)
|
||||
if len(self.dataBuffer) % 50 == 0:
|
||||
current_time = time.time()
|
||||
if self.last_print_time is not None:
|
||||
elapsed_time = current_time - self.last_print_time
|
||||
# 2500个点 = 50个数据块 * 50个采样点/数据块
|
||||
actual_fs = 2500 / elapsed_time
|
||||
print(f"接收 2500 个采样点耗时: {elapsed_time:.4f} 秒, 折合实际采样率: {actual_fs:.2f} Hz")
|
||||
else:
|
||||
print("开始计时...")
|
||||
self.last_print_time = current_time
|
||||
print('数据保存进度: {}/{}'.format(len(self.dataBuffer),int(self.zmqServer.save_win*self.fs//50)))
|
||||
if len(self.dataBuffer) >= int(self.zmqServer.save_win*self.fs//50): #5分钟*60秒*250Hz / 50
|
||||
self.zmqServer.mat_generate = False
|
||||
matData = np.hstack(self.dataBuffer[:int(self.zmqServer.save_win*self.fs//50)])
|
||||
self.dataBuffer = []
|
||||
self.last_print_time = None # 重置计时器以备下次使用
|
||||
self.pack2mat(matData,self.subject_id,self.session_id)
|
||||
|
||||
def pack2mat(self,data,subject_id,session_id):
|
||||
#EEG数据
|
||||
Data = data.T
|
||||
#通道名称
|
||||
channel_names = np.array(
|
||||
['AIN1', 'AIN2', 'AIN3', 'AIN4', 'AIN5', 'AIN6', 'AIN7', 'AIN8', 'AIN9', 'AIN10', 'AIN11', 'AIN12',
|
||||
'AIN13', 'AIN14', 'AIN15', 'AIN16', 'AIN17', 'AIN18', 'AIN19', 'AIN20', 'AIN21', 'AIN22', 'AIN23',
|
||||
'AIN24', 'AIN25', 'AIN26', 'AIN27', 'AIN28', 'AIN29', 'AIN30', 'AIN31', 'AIN32', 'AIN33', 'AIN34',
|
||||
'AIN35', 'AIN36', 'AIN37', 'AIN38', 'AIN39', 'AIN40', 'AIN41', 'AIN42', 'AIN43', 'AIN44', 'AIN45',
|
||||
'AIN46', 'AIN47', 'AIN48', 'AIN49', 'AIN50', 'AIN51', 'AIN52', 'AIN53', 'AIN54', 'AIN55', 'AIN56',
|
||||
'AIN57', 'AIN58', 'AIN59', 'AIN60', 'AIN61', 'AIN62', 'AIN63', 'AIN64'], dtype=object)
|
||||
#采样率
|
||||
sample_rate = self.fs
|
||||
#通道数量
|
||||
node_number = Data.shape[1]
|
||||
# 时间轴
|
||||
t = np.linspace(0, self.zmqServer.save_win, Data.shape[0])
|
||||
t = t.reshape(len(t), 1)
|
||||
#电极名称
|
||||
electrode_name = np.array(['FP1', 'FP2', 'PO6', 'POZ', 'F3', 'F4', 'FPZ', 'AF4', 'FC3', 'PO8', 'CP2', 'CP1',
|
||||
'FCZ', 'PO5', 'FC2', 'FC1', 'C3', 'C4', 'FC4', 'CP4', 'P3', 'P4', 'F5', 'C5', 'F6',
|
||||
'PO4', 'CP6', 'CP5', 'PO3', 'CP3', 'FC6', 'FC5', 'CB1', 'CB2', 'P5', 'AF7', 'A1','T7',
|
||||
'FT7', 'TP7', 'FT8', 'AF8', 'F8', 'F7', 'P6', 'C6', 'O2', 'O1', 'T8', 'P7', 'CZ','PZ',
|
||||
'P8', 'FZ', 'OZ', 'PO7', 'TP8', 'AF3', 'C2', 'C1', 'P2', 'P1', 'F2', 'F1'],
|
||||
dtype=object)
|
||||
#电极三维坐标
|
||||
electrode_xyz = self.read_ch_pos()
|
||||
electrode_xyz.update({'A1': [-0.095, 0, -0.005]})
|
||||
electrode_xyz = {key: electrode_xyz[key] for key in electrode_name}
|
||||
electrode_xyz = np.array(list(electrode_xyz.values()))
|
||||
#电极坐标所属的坐标系
|
||||
electrode_coord_system = '10-20 spherical model'
|
||||
#受试者ID
|
||||
Subject_id = subject_id
|
||||
#Session ID
|
||||
Session_id = session_id
|
||||
#参考电极方案
|
||||
ref = 'CPZ'
|
||||
#数据采集开始时间
|
||||
start_time = 0
|
||||
|
||||
meta_struct = {
|
||||
'subject_id': Subject_id,
|
||||
'session_id': Session_id,
|
||||
'ref': ref,
|
||||
'start_time': start_time
|
||||
}
|
||||
|
||||
eeg_struct = {
|
||||
'data': Data,
|
||||
'chn': channel_names,
|
||||
'sample_rate': sample_rate,
|
||||
'node_number': node_number,
|
||||
't': t,
|
||||
'electrode_name': electrode_name,
|
||||
'electrode_xyz': electrode_xyz,
|
||||
'electrode_coord_system': electrode_coord_system,
|
||||
'meta': meta_struct,
|
||||
}
|
||||
|
||||
fileDir = os.path.join('EEGfiles/',Subject_id,Session_id)
|
||||
os.makedirs(fileDir,exist_ok=True)
|
||||
filePath = os.path.join(fileDir,'eeg_data{}.mat'.format(self.file_num))
|
||||
# 保存到 .mat 文件,顶层变量名为 'eeg'
|
||||
savemat(filePath, {'eeg': eeg_struct})
|
||||
print('EEGfile saved at {}'.format(filePath))
|
||||
self.zmqClient.send_to_all('filePath', filePath)
|
||||
self.file_num += 1
|
||||
|
||||
def read_ch_pos(self,file_path=r'xy_64.xlsx'):
|
||||
"""
|
||||
将电极位置信息转换为Dict
|
||||
|
||||
参数:
|
||||
file_path: 电极位置存储文件, 必须包含'channel', 'x', 'y', 'z'列
|
||||
|
||||
"""
|
||||
if getattr(sys, 'frozen', False):
|
||||
script_dir = sys._MEIPASS
|
||||
else:
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
file_path = os.path.join(script_dir, file_path)
|
||||
df = pd.read_excel(file_path)
|
||||
# 确保列名正确
|
||||
if not all(col in df.columns for col in ['channel', 'x', 'y', 'z']):
|
||||
raise ValueError("DataFrame必须包含'channel', 'x', 'y', 'z'列")
|
||||
# 创建电极位置字典
|
||||
ch_pos = {}
|
||||
for _, row in df.iterrows():
|
||||
ch_pos[row['channel']] = [row['x'], row['y'], row['z']]
|
||||
return ch_pos
|
||||
|
||||
def stop(self):
|
||||
'''
|
||||
停止运行
|
||||
@return:
|
||||
'''
|
||||
self.zmqServer.stop()
|
||||
self.Running=False
|
||||
185
algorithm_V0/datacollect/eeg_quality_check-mat.py
Normal file
185
algorithm_V0/datacollect/eeg_quality_check-mat.py
Normal file
@@ -0,0 +1,185 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
EEG Data Quality Check - eeg_data0.mat
|
||||
===================================
|
||||
1. Time Domain Signal (Full Duration)
|
||||
2. Amplitude Spectrum (FFT)
|
||||
3. Power Spectral Density (Linear Scale)
|
||||
4. Power Spectral Density (dB Scale)
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import mne
|
||||
from scipy import signal
|
||||
from scipy.io import loadmat
|
||||
|
||||
|
||||
def load_and_preprocess(filepath):
|
||||
"""Load .mat file (custom format) and basic preprocessing."""
|
||||
mat_data = loadmat(filepath, simplify_cells=True)
|
||||
eeg = mat_data['eeg']
|
||||
|
||||
# Extract data (shape: samples x channels)
|
||||
data = eeg['data'].T # Transpose to (channels x samples)
|
||||
sfreq = eeg['sample_rate']
|
||||
|
||||
# Get channel names (try multiple possible keys)
|
||||
if 'chn' in eeg:
|
||||
ch_names = list(eeg['chn'])
|
||||
elif 'electrode_name' in eeg:
|
||||
ch_names = list(eeg['electrode_name'])
|
||||
else:
|
||||
n_channels = data.shape[0]
|
||||
ch_names = [f'Ch{i+1}' for i in range(n_channels)]
|
||||
|
||||
# Create MNE Info object
|
||||
info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types='eeg')
|
||||
raw = mne.io.RawArray(data, info)
|
||||
|
||||
raw.filter(l_freq=0.5, h_freq=10, fir_design='firwin', verbose=False)
|
||||
return raw
|
||||
|
||||
|
||||
def main():
|
||||
filepath = r"D:\Ivey\Code_New_Proj\brainplot\plot64\eeg_data0511.mat"
|
||||
output_path = r"D:\Ivey\Code_New_Proj\brainplot\plot64\eeg_quality_check_depression.png"
|
||||
raw = load_and_preprocess(filepath)
|
||||
|
||||
# Print all channel names first
|
||||
print(f"\nAvailable channels ({len(raw.ch_names)}):")
|
||||
for i, ch in enumerate(raw.ch_names):
|
||||
print(f" {i:3d}: {ch}")
|
||||
|
||||
select_channel = ['AIN5']
|
||||
raw.pick(select_channel)
|
||||
|
||||
# Use all channels, full duration
|
||||
ch_names = raw.ch_names
|
||||
n_channels = len(ch_names)
|
||||
data = raw.get_data()
|
||||
sfreq = raw.info['sfreq']
|
||||
n_samples = data.shape[1]
|
||||
duration = n_samples / sfreq
|
||||
|
||||
print(f"Info: {n_channels} channels, {duration:.1f}s, {sfreq:.0f} Hz")
|
||||
|
||||
# Compute frequency domain data
|
||||
n_fft = 2**int(np.ceil(np.log2(n_samples)))
|
||||
freqs_fft = np.fft.rfftfreq(n_fft, 1 / sfreq)
|
||||
fft_vals = np.fft.rfft(data, n=n_fft)
|
||||
amplitude = np.abs(fft_vals) / n_fft * 2
|
||||
|
||||
freqs_psd, psd = signal.welch(data, fs=sfreq, nperseg=4096,
|
||||
noverlap=2048, scaling='density')
|
||||
|
||||
# Frequency mask: 0.5-80 Hz
|
||||
mask_fft = (freqs_fft >= 0.5) & (freqs_fft <= 80)
|
||||
mask_psd = (freqs_psd >= 0.5) & (freqs_psd <= 80)
|
||||
|
||||
freq_fft = freqs_fft[mask_fft]
|
||||
freq_psd = freqs_psd[mask_psd]
|
||||
|
||||
# Plot: 4 rows x 1 column
|
||||
fig, axes = plt.subplots(4, 1, figsize=(16, 20))
|
||||
fig.suptitle(f'EEG Data Quality Check — {", ".join(ch_names)}, '
|
||||
f'Full Duration: {duration:.1f}s',
|
||||
fontsize=16, fontweight='bold', y=0.995)
|
||||
|
||||
# Colormap for distinct channel
|
||||
cmap = plt.cm.tab10 if n_channels <= 10 else plt.cm.tab20
|
||||
colors = [cmap(i) for i in np.linspace(0, 1, n_channels)]
|
||||
|
||||
# ---- Row 1: Time Domain Signal ----
|
||||
ax = axes[0]
|
||||
offset = 0
|
||||
step = max(100, np.std(data, axis=1).mean() * 1e6 * 4)
|
||||
|
||||
# Downsample for display
|
||||
ds = max(1, n_samples // (int(duration) * 500))
|
||||
t = np.arange(0, n_samples, ds) / sfreq
|
||||
|
||||
for i in range(n_channels):
|
||||
sig = data[i, ::ds] * 1e6 + offset
|
||||
ax.plot(t, sig, linewidth=0.5, alpha=0.9, color=colors[i], label=ch_names[i])
|
||||
ax.text(t[0] - 0.5, offset, ch_names[i], fontsize=7, va='center', ha='right', color=colors[i])
|
||||
offset += step
|
||||
|
||||
ax.set_xlim(0, duration)
|
||||
ax.set_xlabel('Time (s)')
|
||||
ax.set_ylabel('Amplitude (μV)')
|
||||
ax.set_title('1. Time Domain Signal (Full Duration)', fontweight='bold')
|
||||
ax.grid(True, alpha=0.3)
|
||||
ax.legend(loc='upper right', fontsize=7, ncol=max(1, n_channels // 3), framealpha=0.8)
|
||||
|
||||
# ---- Row 2: Amplitude Spectrum (FFT) ----
|
||||
ax = axes[1]
|
||||
amp_data = amplitude[:, mask_fft] * 1e6 # (n_channels, n_freqs)
|
||||
for i in range(n_channels):
|
||||
ax.plot(freq_fft, amp_data[i], color=colors[i], linewidth=1.0, alpha=0.85, label=ch_names[i])
|
||||
ax.axvline(50, color='red', linestyle='--', alpha=0.6, label='50 Hz Mains')
|
||||
ax.set_xlim(0.5, 30)
|
||||
ax.set_xlabel('Frequency (Hz)')
|
||||
ax.set_ylabel('Amplitude (μV)')
|
||||
ax.set_title('2. Amplitude Spectrum (FFT)', fontweight='bold')
|
||||
ax.grid(True, alpha=0.3)
|
||||
ax.legend(loc='upper right', fontsize=7, ncol=max(1, n_channels // 3), framealpha=0.8)
|
||||
|
||||
# ---- Row 3: PSD (Linear Scale) ----
|
||||
ax = axes[2]
|
||||
psd_data = psd[:, mask_psd] * 1e12 # (n_channels, n_freqs)
|
||||
for i in range(n_channels):
|
||||
ax.plot(freq_psd, psd_data[i], color=colors[i], linewidth=1.0, alpha=0.85, label=ch_names[i])
|
||||
ax.axvline(50, color='red', linestyle='--', alpha=0.6, label='50 Hz Mains')
|
||||
ax.set_xlim(0.5, 80)
|
||||
ax.set_xlabel('Frequency (Hz)')
|
||||
ax.set_ylabel('Power (μV²/Hz)')
|
||||
ax.set_title('3. Power Spectral Density (Linear Scale)', fontweight='bold')
|
||||
ax.grid(True, alpha=0.3)
|
||||
ax.legend(loc='upper right', fontsize=7, ncol=max(1, n_channels // 3), framealpha=0.8)
|
||||
|
||||
# ---- Row 4: PSD (dB Scale) ----
|
||||
ax = axes[3]
|
||||
for i in range(n_channels):
|
||||
psd_dbi = 10 * np.log10(psd_data[i] + 1e-20)
|
||||
ax.plot(freq_psd, psd_dbi, color=colors[i], linewidth=1.0, alpha=0.85, label=ch_names[i])
|
||||
ax.axvline(50, color='red', linestyle='--', alpha=0.6, label='50 Hz Mains')
|
||||
ax.set_xlim(0.5, 80)
|
||||
ax.set_xlabel('Frequency (Hz)')
|
||||
ax.set_ylabel('Power (dB)')
|
||||
ax.set_title('4. Power Spectral Density (dB Scale)', fontweight='bold')
|
||||
ax.grid(True, alpha=0.3)
|
||||
ax.legend(loc='upper right', fontsize=7, ncol=max(1, n_channels // 3), framealpha=0.8)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.subplots_adjust(top=0.97)
|
||||
|
||||
plt.savefig(output_path, dpi=150, bbox_inches='tight',
|
||||
facecolor='white', edgecolor='none')
|
||||
print(f"Figure saved to: {output_path}")
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
|
||||
from scipy.io import loadmat
|
||||
import numpy as np
|
||||
|
||||
mat = loadmat(r'D:\Ivey\Code_New_Proj\brainplot\plot64\eeg_data0511.mat', simplify_cells=True)
|
||||
data = mat['eeg']['data'] # (samples, channels)
|
||||
sfreq = 250
|
||||
seg1 = data[0:int(10*sfreq), :] # 0-10s
|
||||
seg2 = data[int(10*sfreq):int(20*sfreq), :] # 10-20s
|
||||
|
||||
print('Segment 1 (0-10s) shape:', seg1.shape)
|
||||
print('Segment 2 (10-20s) shape:', seg2.shape)
|
||||
print('Are they equal?', np.allclose(seg1, seg2))
|
||||
print('Max difference:', np.max(np.abs(seg1 - seg2)))
|
||||
print('Mean difference:', np.mean(np.abs(seg1 - seg2)))
|
||||
|
||||
# Check correlation
|
||||
corr = np.corrcoef(seg1.flatten(), seg2.flatten())[0, 1]
|
||||
print(f'Correlation: {corr:.4f}')
|
||||
BIN
algorithm_V0/datacollect/eeg_quality_check_depression.png
Normal file
BIN
algorithm_V0/datacollect/eeg_quality_check_depression.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 521 KiB |
193
algorithm_V0/datacollect/protocol.py
Normal file
193
algorithm_V0/datacollect/protocol.py
Normal file
@@ -0,0 +1,193 @@
|
||||
from typing import List, Tuple, Union, Optional
|
||||
|
||||
|
||||
class ProtocolFrame:
|
||||
# 协议常量
|
||||
FRAME_HEADER = 0xAA
|
||||
FRAME_TAIL1 = 0x55
|
||||
FRAME_TAIL2 = 0x55
|
||||
RESERVED_SIZE = 6
|
||||
MIN_FRAME_SIZE = 13 # 帧头1 + 功能1 + 长度2 + 预留6 + CRC1 + 包尾2
|
||||
MAX_DATA_LENGTH = 0xFFFF # 最大数据长度 (2字节能表示的最大值)
|
||||
|
||||
@staticmethod
|
||||
def calculate_crc8(data: bytes) -> bytes:
|
||||
"""
|
||||
计算CRC8校验值
|
||||
Args:
|
||||
data: 需要计算CRC的数据
|
||||
Returns:
|
||||
一个字节的CRC值(bytes类型)
|
||||
"""
|
||||
crc = 0
|
||||
for byte in data:
|
||||
crc ^= byte
|
||||
for _ in range(8):
|
||||
crc = ((crc << 1) ^ 0x07 if crc & 0x80 else crc << 1) & 0xFF
|
||||
return bytes([crc])
|
||||
|
||||
@classmethod
|
||||
def pack(cls, function, data: Union[bytes, bytearray, List[int]],
|
||||
reserved: Optional[Union[bytes, bytearray, List[int]]] = None) -> bytes:
|
||||
"""
|
||||
协议打包函数
|
||||
|
||||
Args:
|
||||
function: 功能码 (1字节)
|
||||
data: 数据块
|
||||
reserved: 预留字节(6字节,可选)
|
||||
|
||||
Returns:
|
||||
打包后的字节数据
|
||||
"""
|
||||
# 检查功能码
|
||||
if function != None:
|
||||
if not 0 <= function <= 0xFF:
|
||||
raise ValueError("功能码必须是1字节")
|
||||
|
||||
# 转换数据为bytearray
|
||||
if isinstance(data, list):
|
||||
data = bytearray(data)
|
||||
elif isinstance(data, bytes):
|
||||
data = bytearray(data)
|
||||
|
||||
# 检查数据长度
|
||||
data_length = len(data)
|
||||
if data_length > cls.MAX_DATA_LENGTH:
|
||||
raise ValueError(f"数据长度超过最大值 {cls.MAX_DATA_LENGTH}")
|
||||
|
||||
# 处理预留字节
|
||||
if reserved is None:
|
||||
reserved = bytearray([0] * cls.RESERVED_SIZE)
|
||||
else:
|
||||
if isinstance(reserved, list):
|
||||
reserved = bytearray(reserved)
|
||||
elif isinstance(reserved, bytes):
|
||||
reserved = bytearray(reserved)
|
||||
if len(reserved) != cls.RESERVED_SIZE:
|
||||
raise ValueError(f"预留字节必须是{cls.RESERVED_SIZE}字节")
|
||||
|
||||
# 构建帧
|
||||
frame = bytearray([cls.FRAME_HEADER]) # 帧头 (1字节)
|
||||
if function != None:
|
||||
frame.append(function) # 功能码 (1字节)
|
||||
data_length+=6
|
||||
|
||||
# 数据长度 (2字节,大端序)
|
||||
frame.append((data_length >> 8) & 0xFF) # 高字节
|
||||
frame.append(data_length & 0xFF) # 低字节
|
||||
|
||||
if function != None:
|
||||
frame.extend(reserved) # 预留字节 (6字节)
|
||||
frame.extend(data) # 数据块 (变长)
|
||||
|
||||
# 计算CRC (从功能码开始到数据块结束)
|
||||
crc = cls.calculate_crc8(frame[1:]) # 不包含帧头
|
||||
frame.extend(crc) # CRC校验 (1字节)
|
||||
|
||||
# 添加帧尾
|
||||
frame.extend([cls.FRAME_TAIL1, cls.FRAME_TAIL2]) # 帧尾 (2字节)
|
||||
|
||||
return bytes(frame)
|
||||
|
||||
@classmethod
|
||||
def unpack(cls, data: Union[bytes, bytearray]) -> Tuple[int, bytearray, bytearray]:
|
||||
"""
|
||||
协议解包函数
|
||||
|
||||
Args:
|
||||
data: 待解析的字节数据
|
||||
|
||||
Returns:
|
||||
(功能码, 数据块, 预留字节)
|
||||
|
||||
Raises:
|
||||
ValueError: 当数据格式不正确时
|
||||
"""
|
||||
# 检查数据长度
|
||||
if len(data) < cls.MIN_FRAME_SIZE:
|
||||
raise ValueError("数据长度不足")
|
||||
|
||||
# 检查帧头
|
||||
if data[0] != cls.FRAME_HEADER:
|
||||
raise ValueError("帧头错误")
|
||||
|
||||
# 检查帧尾
|
||||
if data[-2:] != bytes([cls.FRAME_TAIL1, cls.FRAME_TAIL2]):
|
||||
raise ValueError("帧尾错误")
|
||||
|
||||
# 解析基本信息
|
||||
function = data[1] # 功能码 (1字节)
|
||||
|
||||
# 数据长度 (2字节,大端序)
|
||||
data_length = (data[2] << 8) | data[3]
|
||||
|
||||
reserved = data[4:10] # 预留字节 (6字节)
|
||||
|
||||
# 检查数据长度
|
||||
expected_length = cls.MIN_FRAME_SIZE + data_length
|
||||
if len(data) != expected_length:
|
||||
raise ValueError(f"数据长度不匹配: 期望{expected_length}字节,实际{len(data)}字节")
|
||||
|
||||
# 提取数据块
|
||||
payload = data[10:10 + data_length]
|
||||
|
||||
# 验证CRC (从功能码开始到数据块结束)
|
||||
received_crc = data[-3]
|
||||
calculated_crc = cls.calculate_crc8(data[1:-3])[0] # 获取字节值
|
||||
|
||||
if received_crc != calculated_crc:
|
||||
raise ValueError(f"CRC校验失败: 期望{calculated_crc:02X},实际{received_crc:02X}")
|
||||
|
||||
return function, bytearray(payload), bytearray(reserved)
|
||||
|
||||
|
||||
|
||||
def print_hex(data: bytes, label: str = ""):
|
||||
"""打印十六进制数据,并按字节添加空格"""
|
||||
hex_str = ' '.join([f"{b:02X}" for b in data])
|
||||
if label:
|
||||
print(f"{label}: {hex_str}")
|
||||
else:
|
||||
print(hex_str)
|
||||
|
||||
|
||||
def print_frame_details(data: bytes):
|
||||
"""打印帧的详细信息"""
|
||||
print("帧详细信息:")
|
||||
print(f"帧头: {data[0]:02X}")
|
||||
print(f"功能码: {data[1]:02X}")
|
||||
print(f"数据长度: {data[2]:02X} {data[3]:02X} ({(data[2] << 8) | data[3]}字节)")
|
||||
print(f"预留字节: {' '.join([f'{b:02X}' for b in data[4:10]])}")
|
||||
data_length = (data[2] << 8) | data[3]
|
||||
print(f"数据块: {' '.join([f'{b:02X}' for b in data[10:10 + data_length]])}")
|
||||
print(f"CRC校验: {data[-3]:02X}")
|
||||
print(f"帧尾: {data[-2]:02X} {data[-1]:02X}")
|
||||
|
||||
|
||||
# 使用示例
|
||||
def example_usage():
|
||||
try:
|
||||
|
||||
|
||||
# 示例1:简单数据打包
|
||||
function_code = 0x01
|
||||
data = [0x1]
|
||||
packed_data = ProtocolFrame.pack(function_code, data)
|
||||
print_hex(packed_data, "示例1 - 完整帧")
|
||||
print_frame_details(packed_data)
|
||||
print()
|
||||
|
||||
# 示例3:解包验证
|
||||
function, payload, reserved = ProtocolFrame.unpack(packed_data)
|
||||
print("解包结果:")
|
||||
print(f"功能码: 0x{function:02X}")
|
||||
print_hex(payload, "数据块")
|
||||
print_hex(reserved, "预留字节")
|
||||
|
||||
except ValueError as e:
|
||||
print(f"错误: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
example_usage()
|
||||
17
algorithm_V0/datacollect/start_parse.py
Normal file
17
algorithm_V0/datacollect/start_parse.py
Normal file
@@ -0,0 +1,17 @@
|
||||
|
||||
import time
|
||||
from eegParser import Parser_main
|
||||
from RunOnce import is_program_running
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
if not is_program_running():
|
||||
parser_ = Parser_main()
|
||||
parser_.connect()
|
||||
|
||||
try:
|
||||
parser_.start()
|
||||
while not parser_.zmqServer.IsExitApp:
|
||||
time.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
parser_.stop()
|
||||
137
algorithm_V0/datacollect/verify_build.py
Normal file
137
algorithm_V0/datacollect/verify_build.py
Normal file
@@ -0,0 +1,137 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
PyInstaller 打包验证脚本
|
||||
用于在没有 EEG 设备的情况下验证打包是否成功
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
import shutil
|
||||
|
||||
def check_pyinstaller_installed():
|
||||
"""检查 PyInstaller 是否安装"""
|
||||
try:
|
||||
result = subprocess.run(['pyinstaller', '--version'],
|
||||
capture_output=True, text=True)
|
||||
print(f"✓ PyInstaller 版本: {result.stdout.strip()}")
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
print("✗ PyInstaller 未安装")
|
||||
return False
|
||||
|
||||
def check_dist_folder():
|
||||
"""检查 dist 文件夹是否存在"""
|
||||
base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
dist_dir = os.path.join(base_dir, 'dist', 'start_parse')
|
||||
|
||||
if os.path.exists(dist_dir):
|
||||
print(f"✓ dist 文件夹存在: {dist_dir}")
|
||||
|
||||
# 检查 exe 文件
|
||||
exe_path = os.path.join(dist_dir, 'start_parse.exe')
|
||||
if os.path.exists(exe_path):
|
||||
size_mb = os.path.getsize(exe_path) / (1024 * 1024)
|
||||
print(f"✓ 可执行文件存在: start_parse.exe ({size_mb:.1f} MB)")
|
||||
else:
|
||||
print("✗ 可执行文件不存在")
|
||||
return False
|
||||
|
||||
# 检查资源文件
|
||||
xlsx_path = os.path.join(dist_dir, 'xy_64.xlsx')
|
||||
if os.path.exists(xlsx_path):
|
||||
print(f"✓ 资源文件存在: xy_64.xlsx")
|
||||
else:
|
||||
print("✗ 资源文件 xy_64.xlsx 不存在")
|
||||
|
||||
return True
|
||||
else:
|
||||
print(f"✗ dist 文件夹不存在,请先运行打包")
|
||||
return False
|
||||
|
||||
def check_dependencies():
|
||||
"""检查关键依赖是否在打包中"""
|
||||
base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
dist_dir = os.path.join(base_dir, 'dist', 'start_parse')
|
||||
|
||||
if not os.path.exists(dist_dir):
|
||||
return False
|
||||
|
||||
# 检查关键 DLL 文件
|
||||
critical_dlls = [
|
||||
# zmq 依赖
|
||||
'libzmq.pyd',
|
||||
# numpy 依赖
|
||||
'numpy.core._multiarray_umath.cp310-win_amd64.pyd',
|
||||
# scipy 依赖
|
||||
'scipy.special._ufuncs.cp310-win_amd64.pyd',
|
||||
]
|
||||
|
||||
print("\n检查关键依赖文件:")
|
||||
found_count = 0
|
||||
for dll in critical_dlls:
|
||||
found = False
|
||||
for root, dirs, files in os.walk(dist_dir):
|
||||
if dll in files:
|
||||
found = True
|
||||
break
|
||||
status = "✓" if found else "✗"
|
||||
print(f" {status} {dll}")
|
||||
if found:
|
||||
found_count += 1
|
||||
|
||||
return found_count >= len(critical_dlls) // 2
|
||||
|
||||
def test_imports():
|
||||
"""测试关键模块是否可以导入"""
|
||||
print("\n测试模块导入:")
|
||||
|
||||
modules = ['zmq', 'serial', 'numpy', 'pandas', 'scipy']
|
||||
success = True
|
||||
|
||||
for mod in modules:
|
||||
try:
|
||||
__import__(mod)
|
||||
print(f" ✓ {mod}")
|
||||
except ImportError as e:
|
||||
print(f" ✗ {mod}: {e}")
|
||||
success = False
|
||||
|
||||
return success
|
||||
|
||||
def main():
|
||||
print("=" * 60)
|
||||
print("PyInstaller 打包验证")
|
||||
print("=" * 60)
|
||||
|
||||
checks = [
|
||||
("1. 检查 PyInstaller 安装", check_pyinstaller_installed),
|
||||
("2. 检查 dist 文件夹", check_dist_folder),
|
||||
("3. 检查依赖文件", check_dependencies),
|
||||
("4. 测试模块导入", test_imports),
|
||||
]
|
||||
|
||||
results = []
|
||||
for name, check_func in checks:
|
||||
print(f"\n{name}")
|
||||
print("-" * 40)
|
||||
results.append(check_func())
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("验证结果汇总:")
|
||||
print("=" * 60)
|
||||
|
||||
all_passed = all(results)
|
||||
if all_passed:
|
||||
print("✓ 所有检查通过!打包成功。")
|
||||
print("\n下一步:")
|
||||
print(" 1. 将 dist/start_parse 文件夹复制到目标电脑")
|
||||
print(" 2. 连接 EEG 设备并运行 start_parse.exe")
|
||||
print(" 3. 观察控制台输出是否正常")
|
||||
else:
|
||||
print("✗ 部分检查未通过,请查看上述详细信息")
|
||||
|
||||
return all_passed
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
57
algorithm_V0/datacollect/zmqClient.py
Normal file
57
algorithm_V0/datacollect/zmqClient.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import threading
|
||||
import time
|
||||
import json
|
||||
import zmq
|
||||
|
||||
|
||||
class zmqClient:
|
||||
def __init__(self, host, port):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.client_socket = None
|
||||
self.running = False
|
||||
|
||||
# 记录客户端连接前的状态
|
||||
self.state = {
|
||||
'status_code': None,
|
||||
'energy': None
|
||||
}
|
||||
|
||||
def connect(self):
|
||||
# 创建 ZeroMQ 上下文
|
||||
self.context = zmq.Context()
|
||||
# 创建 REQ 套接字(请求端)
|
||||
self.client_socket = self.context.socket(zmq.DEALER)
|
||||
# client_id = b'client1'
|
||||
# self.client_socket.setsockopt(zmq.IDENTITY,client_id)
|
||||
self.client_socket.connect(f"tcp://{self.host}:{self.port}") # 连接到服务器
|
||||
self.running = True
|
||||
|
||||
def send_to_all(self, method,params):
|
||||
if method in self.state.keys():
|
||||
self.state[method] = params
|
||||
try:
|
||||
if self.running and self.client_socket != None:
|
||||
msg = {'method': method, 'params': params}
|
||||
# 发送响应
|
||||
# print(msg)
|
||||
self.client_socket.send_multipart([b'', json.dumps(msg).encode('utf-8')])
|
||||
else:
|
||||
if method in self.state.keys():
|
||||
self.state[method] = params
|
||||
except ConnectionResetError:
|
||||
print("Connection lost.")
|
||||
self.running = False
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
|
||||
def close_connection(self):
|
||||
self.running = False
|
||||
self.client_socket.close()
|
||||
self.context.term()
|
||||
print("Client closed explicitly.")
|
||||
# 使用TCP客户端
|
||||
if __name__ == "__main__":
|
||||
client = zmqClient('127.0.0.1', 8099)
|
||||
client.connect()
|
||||
# client.close_connection()
|
||||
119
algorithm_V0/datacollect/zmqServer.py
Normal file
119
algorithm_V0/datacollect/zmqServer.py
Normal file
@@ -0,0 +1,119 @@
|
||||
import numpy as np
|
||||
import zmq
|
||||
import threading
|
||||
import json
|
||||
from SunnyLinker import SunnyLinker64
|
||||
|
||||
class zmqServer(threading.Thread):
|
||||
def __init__(self, host='0.0.0.0', port=8099):
|
||||
threading.Thread.__init__(self)
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.running = False
|
||||
self.get_Impedance = False # 是否返回阻抗值
|
||||
self.open_Impedance = None # 是否开启阻抗检测功能
|
||||
self.StartDecode = False # false 停止解码,true=开始解码
|
||||
self.StartTrain = False # False未进入训练状态,True处于训练状态
|
||||
self.state_mode = None # 'train'为训练状态,’rest'为休息状态,'test'为测试状态
|
||||
self.currentLabel = -1 # 接收刺激端消息,了解刺激端当前的训练标签
|
||||
self.IsExitApp = False # 当socket收到2的时候,就置为True,代表遥退出系统了。
|
||||
self.getReport = False # 获取训练报告内容
|
||||
self.mat_generate = False # 保存mat文件,True开始,False暂停
|
||||
self.reset_mat_buffer = False # 重置缓冲区标志,True表示下次开始采集需要清空旧数据
|
||||
self.subject_id = None #受试者ID
|
||||
self.session_id = None #Session ID
|
||||
self.save_win = 0 #保存数据时长
|
||||
|
||||
self.daemon = True
|
||||
# 创建 ZeroMQ 上下文
|
||||
self.context = zmq.Context()
|
||||
# 创建 REP 套接字(响应端)
|
||||
self.socket = self.context.socket(zmq.ROUTER)
|
||||
self.socket.bind(f"tcp://{self.host}:{self.port}") # 绑定到端口 8099
|
||||
self.targetFreqs = []
|
||||
self.changeTarget = False # 更换目标频率
|
||||
self.sunnyLinker = SunnyLinker64(None, None, None, None,None) #单例模式类,已在Decoder实例化
|
||||
self.labels = [0x01, 0x02,0x03]
|
||||
|
||||
self.decoder_switch = False #更换解码器
|
||||
self.decoder_class = None #解码器类别 'ssvep','ssmvep','mi'
|
||||
def run(self):
|
||||
self.running = True
|
||||
print(f"Server is running on {self.host}:{self.port}")
|
||||
try:
|
||||
while self.running:
|
||||
# 等待客户端请求
|
||||
_,_,message = self.socket.recv_multipart()
|
||||
message = json.loads(message.decode('utf-8'))
|
||||
print(f"Received request: {message}")
|
||||
# 处理请求
|
||||
method = message.get("method")
|
||||
params = message.get("params")
|
||||
if method == "sync":
|
||||
self.state_mode = 'sync'
|
||||
if method == "targetFreqs":
|
||||
if not isinstance(params,list):
|
||||
print('targetFreqs must be a list')
|
||||
continue
|
||||
if params != self.targetFreqs:
|
||||
self.targetFreqs = params
|
||||
self.changeTarget = True
|
||||
if method == "decoderClass":
|
||||
if not isinstance(params,str):
|
||||
print('decoderClass must be a str')
|
||||
continue
|
||||
# if params != self.decoder_class:
|
||||
self.decoder_class = params
|
||||
self.decoder_switch = True
|
||||
if method == "getReport":
|
||||
self.getReport = True
|
||||
if method == "train":#训练状态
|
||||
self.state_mode = 'train'
|
||||
self.StartTrain = True
|
||||
self.currentLabel = params # 当前刺激端的训练标签
|
||||
self.sunnyLinker.push_trigger(self.labels[self.currentLabel])
|
||||
elif method == "predict":#预测状态
|
||||
self.state_mode = 'predict'
|
||||
if params == 1: #开始解码
|
||||
self.StartDecode = True
|
||||
self.sunnyLinker.push_trigger(0x63)
|
||||
elif params == 2: #停止解码
|
||||
self.IsExitApp = True
|
||||
self.running = False
|
||||
elif method == "rest": #休息状态
|
||||
self.state_mode = 'rest'
|
||||
elif method == "impedance":
|
||||
if params == 1:
|
||||
self.open_Impedance = True # 开启阻抗
|
||||
self.get_Impedance = True # 返回阻抗
|
||||
elif params == 2:
|
||||
self.open_Impedance = False # 关闭阻抗
|
||||
self.get_Impedance = False # 停止返回阻抗
|
||||
elif method == "matGenerate":
|
||||
self.subject_id = str(params['subject_id'])
|
||||
self.session_id = str(params['session_id'])
|
||||
self.save_win = int(params['time'])
|
||||
self.mat_generate = True
|
||||
self.reset_mat_buffer = True # 每次发送 matGenerate 都重置缓冲区
|
||||
elif method == "stop": # 停止
|
||||
self.mat_generate = False
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"An socket error occurred: {e}")
|
||||
finally:
|
||||
self.running = False
|
||||
# 关闭套接字和上下文
|
||||
self.socket.close()
|
||||
self.context.term()
|
||||
print("Server socket and context closed.")
|
||||
def stop(self):
|
||||
"""显式关闭服务器"""
|
||||
self.running = False
|
||||
self.socket.close()
|
||||
self.context.term()
|
||||
print("Server closed explicitly.")
|
||||
|
||||
if __name__ == '__main__':
|
||||
server = zmqServer()
|
||||
server.start()
|
||||
Reference in New Issue
Block a user