456 lines
19 KiB
Java
456 lines
19 KiB
Java
package com.vetti.socket;
|
||
|
||
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import com.vetti.common.ai.elevenLabs.ElevenLabsClient;
import com.vetti.common.ai.gpt.OpenAiStreamClient;
import com.vetti.common.ai.gpt.service.OpenAiStreamListenerService;
import com.vetti.common.ai.whisper.WhisperClient;
import com.vetti.common.config.RuoYiConfig;
import com.vetti.common.utils.spring.SpringUtils;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import org.apache.commons.io.FileUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

import javax.sound.sampled.*;
import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.io.*;
import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
|
||
|
||
/**
|
||
* 语音面试 web处理器
|
||
*/
|
||
@Slf4j
|
||
@ServerEndpoint("/voice-websocket/{clientId}")
|
||
@Component
|
||
public class ChatWebSocketHandler {
|
||
|
||
// @Value("${whisper.apiUrl}")
|
||
private String API_URL = "wss://api.openai.com/v1/realtime?intent=transcription";
|
||
|
||
// @Value("${whisper.model}")
|
||
private String MODEL = "gpt-4o-mini-transcribe";
|
||
|
||
// @Value("${whisper.apiKey}")
|
||
private String apiKey = "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";
|
||
|
||
// @Value("${whisper.language}")
|
||
private String language = "en";
|
||
|
||
/**
|
||
* 16kHz
|
||
*/
|
||
private static final int SAMPLE_RATE = 16000;
|
||
/**
|
||
* 4 KB 每次读取
|
||
*/
|
||
private static final int BUFFER_SIZE = 4096;
|
||
/**
|
||
* 每样本 16 位
|
||
*/
|
||
private static final int BITS_PER_SAMPLE = 16;
|
||
|
||
/**
|
||
* 缓存客户端流式解析的语音文本数据
|
||
*/
|
||
private final Map<String,String> cacheClientTts = new ConcurrentHashMap<>();
|
||
|
||
/**
|
||
* 缓存客户端调用OpenAi中的websocket-STT 流式传输数据
|
||
*/
|
||
private final Map<String, WebSocket> cacheWebSocket = new ConcurrentHashMap<>();
|
||
|
||
// 语音文件保存目录
|
||
private static final String VOICE_STORAGE_DIR = "/voice_files/";
|
||
|
||
// 语音结果文件保存目录
|
||
private static final String VOICE_STORAGE_RESULT_DIR = "/voice_result_files/";
|
||
|
||
|
||
public ChatWebSocketHandler() {
|
||
// 初始化存储目录
|
||
File dir = new File(RuoYiConfig.getProfile() + VOICE_STORAGE_DIR);
|
||
if (!dir.exists()) {
|
||
dir.mkdirs();
|
||
}
|
||
|
||
File resultDir = new File(RuoYiConfig.getProfile() + VOICE_STORAGE_RESULT_DIR);
|
||
if (!resultDir.exists()) {
|
||
resultDir.mkdirs();
|
||
}
|
||
}
|
||
|
||
// 连接建立时调用
|
||
@OnOpen
|
||
public void onOpen(Session session, @PathParam("clientId") String clientId) {
|
||
log.info("WebSocket 链接已建立:{}", clientId);
|
||
cacheClientTts.put(clientId,new String());
|
||
//初始化STT流式语音转换文本的socket链接
|
||
createWhisperRealtimeSocket(clientId);
|
||
}
|
||
|
||
// 接收文本消息
|
||
@OnMessage
|
||
public void onTextMessage(Session session, String message,@PathParam("clientId") String clientId) {
|
||
System.out.println("接收到文本消息: " + message);
|
||
try {
|
||
//处理文本结果
|
||
if(StrUtil.isNotEmpty(message)){
|
||
Map<String,String> mapResult = JSONUtil.toBean(JSONUtil.parseObj(message),Map.class);
|
||
String resultFlag = mapResult.get("msg");
|
||
if("done".equals(resultFlag)){
|
||
|
||
//发送消息
|
||
WebSocket webSocket = cacheWebSocket.get(clientId);
|
||
webSocket.send("{\"type\": \"input_audio_buffer.commit\"}");
|
||
webSocket.send("{\"type\": \"response.create\"}");
|
||
// if(webSocket != null){
|
||
// webSocket.close(1000,null);
|
||
// }
|
||
//语音结束,开始进行回答解析
|
||
String cacheResultText = cacheClientTts.get(clientId);
|
||
log.info("1、开始进行AI回答时间:{}",System.currentTimeMillis()/1000);
|
||
//把提问的文字发送给CPT(流式处理)
|
||
OpenAiStreamClient aiStreamClient = SpringUtils.getBean(OpenAiStreamClient.class);
|
||
aiStreamClient.streamChat(cacheResultText, new OpenAiStreamListenerService() {
|
||
@Override
|
||
public void onMessage(String content) {
|
||
log.info("返回AI结果:{}",content.trim());
|
||
// 实时输出内容
|
||
//开始进行语音输出-流式持续输出
|
||
log.info("2、开始进行AI回答时间:{}",System.currentTimeMillis()/1000);
|
||
//把结果文字转成语音文件
|
||
//生成文件
|
||
//生成唯一文件名
|
||
String resultFileName = clientId + "_" + System.currentTimeMillis() + ".wav";
|
||
String resultPathUrl = RuoYiConfig.getProfile() + VOICE_STORAGE_RESULT_DIR + resultFileName;
|
||
ElevenLabsClient elevenLabsClient = SpringUtils.getBean(ElevenLabsClient.class);
|
||
elevenLabsClient.handleTextToVoice(content.trim(), resultPathUrl);
|
||
log.info("3、开始进行AI回答时间:{}",System.currentTimeMillis()/1000);
|
||
//持续返回数据流给客户端
|
||
try {
|
||
String resultOutPathUrl = RuoYiConfig.getProfile() + VOICE_STORAGE_RESULT_DIR + "110_"+resultFileName;
|
||
handleVoice(resultPathUrl,resultOutPathUrl);
|
||
//文件转换成文件流
|
||
ByteBuffer outByteBuffer = convertFileToByteBuffer(resultOutPathUrl);
|
||
//发送文件流数据
|
||
session.getBasicRemote().sendBinary(outByteBuffer);
|
||
// 发送响应确认
|
||
log.info("4、开始进行AI回答时间:{}",System.currentTimeMillis()/1000);
|
||
} catch (IOException e) {
|
||
e.printStackTrace();
|
||
}
|
||
}
|
||
|
||
@Override
|
||
public void onComplete() {
|
||
try {
|
||
Map<String,String> resultEntity = new HashMap<>();
|
||
resultEntity.put("msg","done");
|
||
//发送通知告诉客户端已经回答结束了
|
||
session.getBasicRemote().sendText(JSONUtil.toJsonStr(resultEntity));
|
||
} catch (Exception e) {
|
||
throw new RuntimeException(e);
|
||
}
|
||
log.info("5、结束进行AI回答时间:{}",System.currentTimeMillis()/1000);
|
||
}
|
||
|
||
@Override
|
||
public void onError(Throwable throwable) {
|
||
throwable.printStackTrace();
|
||
}
|
||
});
|
||
}
|
||
}
|
||
} catch (Exception e) {
|
||
e.printStackTrace();
|
||
}
|
||
}
|
||
|
||
// 接收二进制消息(流数据)
|
||
@OnMessage
|
||
public void onBinaryMessage(Session session, @PathParam("clientId") String clientId, ByteBuffer byteBuffer) {
|
||
log.info("1、开始接收数据流时间:{}",System.currentTimeMillis()/1000);
|
||
log.info("客户端ID为:{}", clientId);
|
||
// 处理二进制流数据
|
||
byte[] bytes = new byte[byteBuffer.remaining()];
|
||
//从缓冲区中读取数据并存储到指定的字节数组中
|
||
byteBuffer.get(bytes);
|
||
log.info("2、开始接收数据流时间:{}",System.currentTimeMillis()/1000);
|
||
// 生成唯一文件名
|
||
// String fileName = clientId + "_" + System.currentTimeMillis() + ".webm";
|
||
// String pathUrl = RuoYiConfig.getProfile()+VOICE_STORAGE_DIR + fileName;
|
||
// log.info("文件路径为:{}", pathUrl);
|
||
log.info("3、开始接收数据流时间:{}",System.currentTimeMillis()/1000);
|
||
try{
|
||
//接收到数据流后直接就进行SST处理
|
||
//发送消息
|
||
WebSocket webSocket = cacheWebSocket.get(clientId);
|
||
log.info("获取的socket对象为:{}",webSocket);
|
||
if(webSocket != null){
|
||
// 1. 启动音频缓冲
|
||
// webSocket.send("{\"type\": \"input_audio_buffer.start\"}");
|
||
log.info("3.1 开始发送数据音频流啦");
|
||
// 将音频数据转换为 Base64 编码的字符串
|
||
//进行转换
|
||
// 转换音频格式
|
||
AudioFormat format = new AudioFormat(SAMPLE_RATE, BITS_PER_SAMPLE, 1, true, false);
|
||
byte[] outputAudioBytes = convertAudio(bytes, format);
|
||
String base64Audio = Base64.getEncoder().encodeToString(outputAudioBytes);
|
||
String message = "{ \"type\": \"input_audio_buffer.append\", \"audio\": \"" + base64Audio + "\" }";
|
||
webSocket.send(message);
|
||
log.info("4、开始接收数据流时间:{}",System.currentTimeMillis()/1000);
|
||
// 3. 提交音频并请求转录
|
||
// webSocket.send("{\"type\": \"input_audio_buffer.commit\"}");
|
||
// webSocket.send("{\"type\": \"response.create\"}");
|
||
}
|
||
}catch (Exception e){
|
||
e.printStackTrace();
|
||
}
|
||
|
||
}
|
||
|
||
// 连接关闭时调用
|
||
@OnClose
|
||
public void onClose(Session session, CloseReason reason) {
|
||
System.out.println("WebSocket连接已关闭: " + session.getId() + ", 原因: " + reason.getReasonPhrase());
|
||
}
|
||
|
||
// 发生错误时调用
|
||
@OnError
|
||
public void onError(Session session, Throwable throwable) {
|
||
System.err.println("WebSocket错误发生: " + throwable.getMessage());
|
||
throwable.printStackTrace();
|
||
}
|
||
|
||
public static byte[] convertAudio(byte[] inputAudioBytes, AudioFormat targetFormat) throws Exception {
|
||
// 将 byte[] 转换为 AudioInputStream
|
||
ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(inputAudioBytes);
|
||
AudioInputStream inputAudioStream = new AudioInputStream(byteArrayInputStream, targetFormat, inputAudioBytes.length);
|
||
|
||
// 创建目标格式的 AudioInputStream
|
||
AudioInputStream outputAudioStream = AudioSystem.getAudioInputStream(targetFormat, inputAudioStream);
|
||
|
||
// 获取输出音频的 byte[]
|
||
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
|
||
byte[] buffer = new byte[1024];
|
||
int bytesRead;
|
||
|
||
// 从 AudioInputStream 读取数据并写入 ByteArrayOutputStream
|
||
while ((bytesRead = outputAudioStream.read(buffer)) != -1) {
|
||
byteArrayOutputStream.write(buffer, 0, bytesRead);
|
||
}
|
||
|
||
// 返回转换后的 byte[]
|
||
return byteArrayOutputStream.toByteArray();
|
||
}
|
||
|
||
/**
|
||
* 将字节数组保存为WebM文件
|
||
*
|
||
* @param byteData 包含WebM数据的字节数组
|
||
* @param filePath 目标文件路径
|
||
* @return 操作是否成功
|
||
*/
|
||
private boolean saveAsWebM(byte[] byteData, String filePath) {
|
||
// 检查输入参数
|
||
if (byteData == null || byteData.length == 0) {
|
||
System.err.println("字节数组为空,无法生成WebM文件");
|
||
return false;
|
||
}
|
||
|
||
if (filePath == null || filePath.trim().isEmpty()) {
|
||
System.err.println("文件路径不能为空");
|
||
return false;
|
||
}
|
||
|
||
// 确保文件以.webm结尾
|
||
if (!filePath.toLowerCase().endsWith(".webm")) {
|
||
filePath += ".webm";
|
||
}
|
||
|
||
FileOutputStream fos = null;
|
||
try {
|
||
fos = new FileOutputStream(filePath);
|
||
fos.write(byteData);
|
||
fos.flush();
|
||
System.out.println("WebM文件已成功生成: " + filePath);
|
||
return true;
|
||
} catch (IOException e) {
|
||
System.err.println("写入文件时发生错误: " + e.getMessage());
|
||
e.printStackTrace();
|
||
} finally {
|
||
if (fos != null) {
|
||
try {
|
||
fos.close();
|
||
} catch (IOException e) {
|
||
e.printStackTrace();
|
||
}
|
||
}
|
||
}
|
||
return false;
|
||
}
|
||
|
||
/**
|
||
* File 转换成 ByteBuffer
|
||
*
|
||
* @param fileUrl 文件路径
|
||
* @return
|
||
*/
|
||
private ByteBuffer convertFileToByteBuffer(String fileUrl) {
|
||
File file = new File(fileUrl);
|
||
try {
|
||
return ByteBuffer.wrap(FileUtils.readFileToByteArray(file));
|
||
} catch (Exception e) {
|
||
e.printStackTrace();
|
||
}
|
||
return null;
|
||
}
|
||
|
||
/**
|
||
* 创建STT WebSocket 客户端链接
|
||
* @param clientId 客户端ID
|
||
*/
|
||
private void createWhisperRealtimeSocket(String clientId){
|
||
try{
|
||
OkHttpClient client = new OkHttpClient();
|
||
// CountDownLatch latch = new CountDownLatch(1);
|
||
// 设置 WebSocket 请求
|
||
Request request = new Request.Builder()
|
||
.url(API_URL)
|
||
.addHeader("Authorization", "Bearer " + apiKey)
|
||
.addHeader("OpenAI-Beta", "realtime=v1")
|
||
.build();
|
||
client.newWebSocket(request, new WebSocketListener() {
|
||
@Override
|
||
public void onOpen(WebSocket webSocket, Response response) {
|
||
System.out.println("✅ WebSocket 连接成功");
|
||
//发送配置
|
||
JSONObject config = new JSONObject();
|
||
JSONObject sessionConfig = new JSONObject();
|
||
JSONObject transcription = new JSONObject();
|
||
JSONObject turnDetection = new JSONObject();
|
||
// 配置转录参数
|
||
transcription.put("model", MODEL);
|
||
transcription.put("language", language); // 中文
|
||
// 配置断句检测
|
||
turnDetection.put("type", "server_vad");
|
||
turnDetection.put("prefix_padding_ms", 300);
|
||
turnDetection.put("silence_duration_ms", 10);
|
||
// 组装完整配置
|
||
sessionConfig.put("input_audio_transcription", transcription);
|
||
sessionConfig.put("turn_detection", turnDetection);
|
||
config.put("type", "transcription_session.update");
|
||
config.put("session", sessionConfig);
|
||
webSocket.send(config.toString());
|
||
|
||
// 1. 启动音频缓冲
|
||
webSocket.send("{\"type\": \"input_audio_buffer.start\"}");
|
||
|
||
//存储客户端webSocket对象,对数据进行隔离处理
|
||
cacheWebSocket.put(clientId,webSocket);
|
||
}
|
||
|
||
@Override
|
||
public void onMessage(WebSocket webSocket, String text) {
|
||
System.out.println("📩 收到转录结果: " + text);
|
||
//对数据进行解析
|
||
if(StrUtil.isNotEmpty(text)){
|
||
Map<String,String> mapResultData = JSONUtil.toBean(text,Map.class);
|
||
if("conversation.item.input_audio_transcription.delta".equals(mapResultData.get("type"))){
|
||
String resultText = mapResultData.get("delta");
|
||
//进行客户端文本数据存储
|
||
String cacheString = cacheClientTts.get(clientId);
|
||
if(StrUtil.isNotEmpty(cacheString)){
|
||
cacheString = cacheString+resultText;
|
||
}else {
|
||
cacheString = resultText;
|
||
}
|
||
cacheClientTts.put(clientId,cacheString);
|
||
}
|
||
}
|
||
}
|
||
|
||
@Override
|
||
public void onFailure(WebSocket webSocket, Throwable t, Response response) {
|
||
System.err.println("❌ 连接失败: " + t.getMessage());
|
||
// latch.countDown();
|
||
}
|
||
|
||
@Override
|
||
public void onClosing(WebSocket webSocket, int code, String reason) {
|
||
System.out.println("⚠️ 连接即将关闭: " + reason);
|
||
webSocket.close(1000, null);
|
||
// latch.countDown();
|
||
}
|
||
});
|
||
// 等待 WebSocket 关闭
|
||
// latch.await();
|
||
}catch (Exception e){
|
||
e.printStackTrace();
|
||
}
|
||
}
|
||
|
||
private void handleVoice(String inputPath,String outputPath){
|
||
double trimMs = 270; // 要去掉的尾部时长(毫秒)
|
||
|
||
try {
|
||
// 1. 解析音频格式和总长度
|
||
AudioInputStream audioIn = AudioSystem.getAudioInputStream(new File(inputPath));
|
||
AudioFormat format = audioIn.getFormat();
|
||
long totalBytes = audioIn.getFrameLength() * format.getFrameSize(); // 总字节数
|
||
|
||
// 2. 计算300毫秒对应的字节数
|
||
float sampleRate = format.getSampleRate(); // 采样率(Hz)
|
||
int frameSize = format.getFrameSize(); // 每帧字节数(位深/8 * 声道数)
|
||
double trimSeconds = trimMs / 1000.0; // 转换为秒
|
||
long trimBytes = (long) (sampleRate * trimSeconds * frameSize); // 要去掉的字节数
|
||
|
||
// 3. 计算需要保留的字节数(避免负数)
|
||
long keepBytes = Math.max(0, totalBytes - trimBytes);
|
||
if (keepBytes == 0) {
|
||
System.out.println("音频长度小于300毫秒,无法截断");
|
||
return;
|
||
}
|
||
|
||
// 4. 读取并保留前半部分(去掉最后300毫秒)
|
||
try (InputStream in = new FileInputStream(inputPath);
|
||
OutputStream out = new FileOutputStream(outputPath)) {
|
||
|
||
byte[] buffer = new byte[4096];
|
||
long totalRead = 0;
|
||
int bytesRead;
|
||
|
||
while (totalRead < keepBytes && (bytesRead = in.read(buffer)) != -1) {
|
||
long remaining = keepBytes - totalRead;
|
||
int writeBytes = (remaining < bytesRead) ? (int) remaining : bytesRead;
|
||
out.write(buffer, 0, writeBytes);
|
||
totalRead += writeBytes;
|
||
}
|
||
|
||
System.out.println("处理完成,去掉了最后" + trimMs + "毫秒,保留了" + totalRead + "字节");
|
||
}
|
||
|
||
} catch (UnsupportedAudioFileException | IOException e) {
|
||
e.printStackTrace();
|
||
}
|
||
}
|
||
|
||
}
|
||
|