Files
Depression_TMS/algorithm_V0/datacollect/SunnyLinker.py

380 lines
14 KiB
Python
Raw Normal View History

2026-06-01 13:18:36 +08:00
# -*-coding:utf-8 -*-
'''
SunnyLinker的通讯驱动
'''
import ast
import socket
import threading
import time
import datetime
from typing import Dict
import numpy as np
from threading import Thread, Event
import serial
from scipy import signal
from serial.serialutil import SerialException
from protocol import ProtocolFrame
class RingBuffer:
def __init__(self, n_chan, n_points):
self.n_chan = n_chan
self.n_points = n_points
self.buffer = np.zeros((n_chan, n_points))
self.currentPtr = 0
self.readPtr = 0
self.nUpdate = 0
self.rawData = np.zeros((n_chan, 1))
## append buffer and update current pointer
def appendBuffer(self, data):
if self.nUpdate == self.n_points:
raise Exception("Buffer is full")
n = data.shape[1]
# 计算可以写入的元素数量
write_count = min(self.n_points - self.nUpdate, n)
# 写入新数据
self.buffer[:, np.mod(np.arange(self.currentPtr, self.currentPtr + write_count), self.n_points)] = data[:,:write_count]
# 更新结束指针
self.currentPtr = (self.currentPtr + write_count) % self.n_points
# 更新大小
self.nUpdate += write_count
## get data from buffer
def getData(self, count=50):
# 确保不会尝试读取超过缓冲区当前大小的数据
count = min(count, self.nUpdate)
# 计算读取结束后的下一个位置
next_read_ptr = (self.readPtr + count) % self.n_points
if self.readPtr + count <= self.n_points:
# 情况 1不环绕数据是连续的
end_index = next_read_ptr if next_read_ptr != 0 else self.n_points
data = self.buffer[:, self.readPtr:end_index]
else:
# 情况 2发生环绕数据被分成两部分
# 第一部分:从 readPtr 到缓冲区末尾
part1 = self.buffer[:, self.readPtr:]
# 第二部分:从缓冲区开头到 (count - part1.shape[1]) 个点
part2 = self.buffer[:, :next_read_ptr]
# 将两部分在列方向上拼接
data = np.concatenate((part1, part2), axis=1)
# 更新读指针
self.readPtr = next_read_ptr
# 更新大小
self.nUpdate -= count
return data
# reset buffer
def resetAllPara(self):
self.nUpdate = 0
self.currentPtr = 0
self.readPtr = 0 # add by lizhenhua 清空读指针
self.buffer = np.zeros((self.n_chan, self.n_points)) # add by lizhenhua 清空环形缓冲区
class SunnyLinker64(Thread, ):
t_buffer = 10
n_chan = 64
srate = 250
receiveData = b''
toUv=True#转为uV
RingBufferLock = threading.Lock()
# 单例模式
_instance = None
_initialized = False # 检查是否已经初始化
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super(SunnyLinker64, cls).__new__(cls)
return cls._instance
def __init__(self, host='127.0.0.1', port=7878, srate=250, n_chan=64,method = 'tcp'):
if SunnyLinker64._initialized:
return
Thread.__init__(self)
self.daemon = True
self.host = host
self.port = port
self.srate = srate
self.n_chan = n_chan
self.method = method #传输方式,'tcp'表示tcp传输,'serial'表示串口传输
self.__ringBuffer = RingBuffer(self.n_chan + 2,
int(np.round(self.t_buffer * self.srate)))
self.energy = 0 # 电量
self.status_code = 0 # 与采集设备通信的状态码0为异常1为正常
self.gain_value = 6 # 增益倍数
# 设置初始化标志为True防止重复初始化
SunnyLinker64._initialized = True
# --- 新增:用于心跳检测 ---
self.last_called = 0 # 初始化为0
self.last_called_lock = threading.Lock() # 保护 last_called 的访问
def set_sampleRate(self,sampleRate_Code=0x00):
'''
设置采样率
:param sampleRate_Code: 0x00:250Hz,0x01:500Hz,0x02:1000Hz,0x03:2000Hz
'''
function_code = 0x02
gain_code = 0x06
sampleRate_Code = [gain_code,sampleRate_Code]
packed_data = ProtocolFrame.pack(function_code, sampleRate_Code)
if self.method == 'tcp':
self.sock.send(packed_data)
def push_trigger(self,label):
'''
数据打标
@param label:标签类别
'''
function_code = None
label = [label]
packed_data = ProtocolFrame.pack(function_code, label)
if self.method == 'tcp' and hasattr(self,'serial'):
print('发送:', label, datetime.datetime.now().strftime('%H:%M:%S.%f')[:-3])
self.serial.write(packed_data)
def Impedance(self, On):
'''
阻抗检测开关
:param On:True为开启False为关闭
:return: 组好的协议帧
'''
function_code = 0x01
if On:
data = [0x1]
self.gain_value = 6
else:
data = [0x0]
self.gain_value = 6
packed_data = ProtocolFrame.pack(function_code, data)
if self.method == 'tcp':
self.sock.send(packed_data)
def connect(self):
try:
if self.method == 'serial':
# 开启com口波特率115200超时5
self.sock = serial.Serial(self.host, self.port, timeout=5)
self.sock.flushInput() # 清空缓冲区
count = self.sock.inWaiting() # 获取串口缓冲区数据
while not count:
count = self.sock.inWaiting() # 获取串口缓冲区数据
# # 接收和存储数据
data = (self.sock.read(count))
self.receiveData = self.receiveData + data # 将接收数据存储在字符串中
elif self.method == 'tcp':
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.sock.connect((self.host, int(self.port)))
self.set_sampleRate(0x00) #设置250Hz采样率
except Exception as e:
print("请打开头环")
print(e)
print("connected")
def extract_packet(self, packet):
# 存储一个点的八通道数据
dataList = []
# 存储116个点的八通道数据
dataMatrix = []
for j in range(5):
for i in range(self.n_chan):
if not self.toUv:#原始数据直接输出
val = (packet[194 * j + 25 + i * 3] << 16) | (packet[194 * j + 25 + 1 + i * 3] << 8) | packet[
194 * j + 25 + 2 + i * 3]
else:#转为uV
val = (packet[194 * j + 25 + i * 3] << 16) | (packet[194 * j + 25 + 1 + i * 3] << 8) | packet[
194 * j + 25 + 2 + i * 3]
if val < 8388608:
val = val * 4.5 / self.gain_value / 8388608 * 1000000;
else:
val = (val - 16777216) * 4.5 / self.gain_value / 8388608 * 1000000;
dataList.append(val)
#同步触发源
val = packet[194 * j + 25 + (i+1) * 3]
dataList.append(val)
#同步触发序号
val = packet[194 * j + 25 + (i+1) * 3+1]
dataList.append(val)
# 将数据矩阵进行拼接
if len(dataMatrix) == 0:
dataMatrix = np.asmatrix(dataList)
else:
dataMatrix = np.concatenate((dataMatrix, np.asmatrix(dataList)), axis=0)
dataList.clear()
return np.transpose(dataMatrix)
def run(self):
self.connect()
self.running = True
self.PackageLength = 998
# 启动心跳检测线程
threading.Thread(target=self.heartbeat_checker, daemon=True).start()
while self.running:
try:
if self.method == 'serial':
count = self.sock.inWaiting() # 获取串口缓冲区数据
if count:
# 接收和存储数据
data = (self.sock.read(count))
self.receiveData = self.receiveData + data # 将接收数据存储在字符串中
elif self.method == 'tcp':
data = self.sock.recv(600)
if not data:
break
self.receiveData += data
with self.last_called_lock:
self.last_called = time.time()
self.status_code = 1 # 收到数据,标记为正常
if len(self.receiveData) >= self.PackageLength and self.receiveData.rfind(
b'\x55\x55') >= self.PackageLength - 2:
index = self.receiveData.index(b'\xaa')
self.receiveData = self.receiveData[index:]
if len(self.receiveData) >= self.PackageLength:
onepackage = self.receiveData[:self.PackageLength]
if onepackage[7] != 0:
self.energy = onepackage[7] # 电量
self.receiveData = self.receiveData[self.PackageLength:]
dataMatrix = self.extract_packet(onepackage)
try:
with self.RingBufferLock:
self.__ringBuffer.appendBuffer(dataMatrix)
except Exception as e:
print("锁:写入异常",e)
# self.RingBufferLock.release()
except ConnectionResetError:
self.status_code = 0 # 状态异常
print("Connection was reset by the peer.")
break
self.sock.close()
# --- 新增:心跳检测线程 ---
def heartbeat_checker(self):
"""
定期检查是否在最近2秒内收到 eegData
如果超过2秒未收到则设置 status_code = 0
"""
while self.running:
time.sleep(0.5) # 每0.5秒检查一次
with self.last_called_lock:
now = time.time()
# 只有收到过一次数据后才开始判断超时
if self.last_called > 0 and (now - self.last_called) > 2:
if self.status_code != 0:
print("EEG data timeout: disconnected")
self.status_code = 0
def getImpedance(self, data,n_chan):
'''
获取阻抗值已经放大100倍单位是kΩ
@param data: 准备计算的通道数据每通道200个值注意不要把信号打标的通道传进来
@return: 返回各个通道的阻抗值
'''
impedanceList = []
data = data[:n_chan]
for channelindex in range(data.shape[0]):
if len(data[channelindex]) > 0:
data_list = []
# 设计陷波滤波器去除50Hz成分
is50filter = True
if is50filter:
b, a = signal.iirnotch(50, 30, self.srate) # 30是带宽1000是采样频率
data_list = signal.lfilter(b, a, data[channelindex].tolist())
else:
data_list.extend(data[channelindex].tolist())
data_list = data_list[-1000:]
# 执行FFT
fft_result = np.fft.fft(data_list)
fft_magnitude = np.abs(fft_result / len(data_list)) # 归一化FFT结果
freqs = np.fft.fftfreq(len(data_list), d=1 / self.srate) # 频率轴
# y_amp_modified = np.concatenate(([fft_magnitude[0] / len(t[0].tolist())],
# fft_magnitude[1:-1] * 2 / len(t[0].tolist()),
# [fft_magnitude[-1] / len(t[0].tolist())]))
# 找到幅值最大的频率成分的索引忽略直流分量即索引0
max_index = np.argmax(fft_magnitude[1:])
# 获取最大幅值的频率索引加上1因为索引0是直流分量
freq_index = max_index + 1
# 获取最大幅值
max_magnitude = fft_magnitude[freq_index]
# 阻抗
import math
result = math.sqrt(2) * math.pi * max_magnitude / 6 / 4
result *= 0.44 * 100 # 统一放大100倍
impedanceList.append(int(result))
# print(max_magnitude, result)
else:
impedanceList.append(0)
impedances = np.array(impedanceList)
return impedances
def getData(self,count):
'''
获取最新的数据
@param count: 每通道返回的最数值数目
@return: 所有通道的最新count个数值
'''
data=None
try:
with self.RingBufferLock:
data = self.__ringBuffer.getData(count)
except:
print("锁:读取异常")
# self.RingBufferLock.release()
return data
def GetDataLenCount(self):
'''
获取最新缓存中每个通道的数量
@return:
'''
return self.__ringBuffer.nUpdate
def ResetAll(self):
'''
清空缓存
@return:
'''
with self.RingBufferLock:
self.__ringBuffer.resetAllPara()
def stop(self):
self.running = False
if __name__ == "__main__":
# Usage
Linker = SunnyLinker64('127.0.0.1', 5086, 1000, 65)
Linker.start()
try:
while True:
time.sleep(0.005)
if(Linker.count()>0):
# print(Linker.ringBuffer.nUpdate)
t = Linker.getData()
print(t.shape[1], Linker.count())
# Linker.ringBuffer.nUpdate=0
# time.sleep(0.2)
except KeyboardInterrupt:
Linker.stop()