/*
 * Page-export residue preserved for provenance (was not valid Java source):
 *   Vetti-Service-new/vetti-admin/src/main/java/com/vetti/socket/ChatWebSocketHandler.java
 *   2025-10-19 09:55:32 +08:00 — 406 lines, 17 KiB, Java
 *   Viewer warning: file contains ambiguous Unicode characters (emoji in log strings).
 */
package com.vetti.socket;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import com.vetti.common.ai.elevenLabs.ElevenLabsClient;
import com.vetti.common.ai.gpt.OpenAiStreamClient;
import com.vetti.common.ai.gpt.service.OpenAiStreamListenerService;
import com.vetti.common.ai.whisper.WhisperClient;
import com.vetti.common.config.RuoYiConfig;
import com.vetti.common.utils.spring.SpringUtils;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import org.apache.commons.io.FileUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import javax.sound.sampled.*;
import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import javax.sound.sampled.*;
import java.io.*;
import java.nio.ByteBuffer;
/**
* 语音面试 web处理器
*/
@Slf4j
@ServerEndpoint("/voice-websocket/{clientId}")
@Component
public class ChatWebSocketHandler {
// @Value("${whisper.apiUrl}")
private String API_URL = "wss://api.openai.com/v1/realtime?intent=transcription";
// @Value("${whisper.model}")
private String MODEL = "gpt-4o-mini-transcribe";
// @Value("${whisper.apiKey}")
private String apiKey = "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";
// @Value("${whisper.language}")
private String language = "en";
/**
* 16kHz
*/
private static final int SAMPLE_RATE = 16000;
/**
* 4 KB 每次读取
*/
private static final int BUFFER_SIZE = 4096;
/**
* 每样本 16 位
*/
private static final int BITS_PER_SAMPLE = 16;
/**
* 缓存客户端流式解析的语音文本数据
*/
private final Map<String,String> cacheClientTts = new ConcurrentHashMap<>();
/**
* 缓存客户端调用OpenAi中的websocket-STT 流式传输数据
*/
private final Map<String, WebSocket> cacheWebSocket = new ConcurrentHashMap<>();
// 语音文件保存目录
private static final String VOICE_STORAGE_DIR = "/voice_files/";
// 语音结果文件保存目录
private static final String VOICE_STORAGE_RESULT_DIR = "/voice_result_files/";
public ChatWebSocketHandler() {
// 初始化存储目录
File dir = new File(RuoYiConfig.getProfile() + VOICE_STORAGE_DIR);
if (!dir.exists()) {
dir.mkdirs();
}
File resultDir = new File(RuoYiConfig.getProfile() + VOICE_STORAGE_RESULT_DIR);
if (!resultDir.exists()) {
resultDir.mkdirs();
}
}
// 连接建立时调用
@OnOpen
public void onOpen(Session session, @PathParam("clientId") String clientId) {
log.info("WebSocket 链接已建立:{}", clientId);
cacheClientTts.put(clientId,new String());
//初始化STT流式语音转换文本的socket链接
createWhisperRealtimeSocket(clientId);
}
// 接收文本消息
@OnMessage
public void onTextMessage(Session session, String message,@PathParam("clientId") String clientId) {
System.out.println("接收到文本消息: " + message);
try {
//处理文本结果
if(StrUtil.isNotEmpty(message)){
Map<String,String> mapResult = JSONUtil.toBean(JSONUtil.parseObj(message),Map.class);
String resultFlag = mapResult.get("msg");
if("done".equals(resultFlag)){
//发送消息
WebSocket webSocket = cacheWebSocket.get(clientId);
if(webSocket != null){
webSocket.close(1000,null);
}
//语音结束,开始进行回答解析
String cacheResultText = cacheClientTts.get(clientId);
log.info("1、开始进行AI回答时间:{}",System.currentTimeMillis()/1000);
//把提问的文字发送给CPT(流式处理)
OpenAiStreamClient aiStreamClient = SpringUtils.getBean(OpenAiStreamClient.class);
aiStreamClient.streamChat(cacheResultText, new OpenAiStreamListenerService() {
@Override
public void onMessage(String content) {
log.info("返回AI结果{}",content);
// 实时输出内容
//开始进行语音输出-流式持续输出
log.info("2、开始进行AI回答时间:{}",System.currentTimeMillis()/1000);
//把结果文字转成语音文件
//生成文件
// 生成唯一文件名
String resultFileName = clientId + "_" + System.currentTimeMillis() + ".opus";
String resultPathUrl = RuoYiConfig.getProfile() + VOICE_STORAGE_RESULT_DIR + resultFileName;
ElevenLabsClient elevenLabsClient = SpringUtils.getBean(ElevenLabsClient.class);
elevenLabsClient.handleTextToVoice(content, resultPathUrl);
log.info("3、开始进行AI回答时间:{}",System.currentTimeMillis()/1000);
//持续返回数据流给客户端
try {
//文件转换成文件流
ByteBuffer outByteBuffer = convertFileToByteBuffer(resultPathUrl);
//发送文件流数据
session.getBasicRemote().sendBinary(outByteBuffer);
// 发送响应确认
log.info("4、开始进行AI回答时间:{}",System.currentTimeMillis()/1000);
} catch (IOException e) {
e.printStackTrace();
}
}
@Override
public void onComplete() {
try {
Map<String,String> resultEntity = new HashMap<>();
resultEntity.put("msg","done");
//发送通知告诉客户端已经回答结束了
session.getBasicRemote().sendText(JSONUtil.toJsonStr(resultEntity));
} catch (Exception e) {
throw new RuntimeException(e);
}
log.info("5、结束进行AI回答时间:{}",System.currentTimeMillis()/1000);
}
@Override
public void onError(Throwable throwable) {
throwable.printStackTrace();
}
});
}
}
} catch (Exception e) {
e.printStackTrace();
}
}
// 接收二进制消息(流数据)
@OnMessage
public void onBinaryMessage(Session session, @PathParam("clientId") String clientId, ByteBuffer byteBuffer) {
log.info("1、开始接收数据流时间:{}",System.currentTimeMillis()/1000);
log.info("客户端ID为:{}", clientId);
// 处理二进制流数据
byte[] bytes = new byte[byteBuffer.remaining()];
//从缓冲区中读取数据并存储到指定的字节数组中
byteBuffer.get(bytes);
log.info("2、开始接收数据流时间:{}",System.currentTimeMillis()/1000);
// 生成唯一文件名
// String fileName = clientId + "_" + System.currentTimeMillis() + ".webm";
// String pathUrl = RuoYiConfig.getProfile()+VOICE_STORAGE_DIR + fileName;
// log.info("文件路径为:{}", pathUrl);
log.info("3、开始接收数据流时间:{}",System.currentTimeMillis()/1000);
try{
//接收到数据流后直接就进行SST处理
//发送消息
WebSocket webSocket = cacheWebSocket.get(clientId);
log.info("获取的socket对象为:{}",webSocket);
if(webSocket != null){
// 1. 启动音频缓冲
// webSocket.send("{\"type\": \"input_audio_buffer.start\"}");
log.info("3.1 开始发送数据音频流啦");
// 将音频数据转换为 Base64 编码的字符串
//进行转换
// 转换音频格式
AudioFormat format = new AudioFormat(SAMPLE_RATE, BITS_PER_SAMPLE, 1, true, false);
byte[] outputAudioBytes = convertAudio(bytes, format);
String base64Audio = Base64.getEncoder().encodeToString(outputAudioBytes);
String message = "{ \"type\": \"input_audio_buffer.append\", \"audio\": \"" + base64Audio + "\" }";
webSocket.send(message);
log.info("4、开始接收数据流时间:{}",System.currentTimeMillis()/1000);
// 3. 提交音频并请求转录
// webSocket.send("{\"type\": \"input_audio_buffer.commit\"}");
// webSocket.send("{\"type\": \"response.create\"}");
}
}catch (Exception e){
e.printStackTrace();
}
}
// 连接关闭时调用
@OnClose
public void onClose(Session session, CloseReason reason) {
System.out.println("WebSocket连接已关闭: " + session.getId() + ", 原因: " + reason.getReasonPhrase());
}
// 发生错误时调用
@OnError
public void onError(Session session, Throwable throwable) {
System.err.println("WebSocket错误发生: " + throwable.getMessage());
throwable.printStackTrace();
}
public static byte[] convertAudio(byte[] inputAudioBytes, AudioFormat targetFormat) throws Exception {
// 将 byte[] 转换为 AudioInputStream
ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(inputAudioBytes);
AudioInputStream inputAudioStream = new AudioInputStream(byteArrayInputStream, targetFormat, inputAudioBytes.length);
// 创建目标格式的 AudioInputStream
AudioInputStream outputAudioStream = AudioSystem.getAudioInputStream(targetFormat, inputAudioStream);
// 获取输出音频的 byte[]
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
byte[] buffer = new byte[1024];
int bytesRead;
// 从 AudioInputStream 读取数据并写入 ByteArrayOutputStream
while ((bytesRead = outputAudioStream.read(buffer)) != -1) {
byteArrayOutputStream.write(buffer, 0, bytesRead);
}
// 返回转换后的 byte[]
return byteArrayOutputStream.toByteArray();
}
/**
* 将字节数组保存为WebM文件
*
* @param byteData 包含WebM数据的字节数组
* @param filePath 目标文件路径
* @return 操作是否成功
*/
private boolean saveAsWebM(byte[] byteData, String filePath) {
// 检查输入参数
if (byteData == null || byteData.length == 0) {
System.err.println("字节数组为空无法生成WebM文件");
return false;
}
if (filePath == null || filePath.trim().isEmpty()) {
System.err.println("文件路径不能为空");
return false;
}
// 确保文件以.webm结尾
if (!filePath.toLowerCase().endsWith(".webm")) {
filePath += ".webm";
}
FileOutputStream fos = null;
try {
fos = new FileOutputStream(filePath);
fos.write(byteData);
fos.flush();
System.out.println("WebM文件已成功生成: " + filePath);
return true;
} catch (IOException e) {
System.err.println("写入文件时发生错误: " + e.getMessage());
e.printStackTrace();
} finally {
if (fos != null) {
try {
fos.close();
} catch (IOException e) {
e.printStackTrace();
}
}
}
return false;
}
/**
* File 转换成 ByteBuffer
*
* @param fileUrl 文件路径
* @return
*/
private ByteBuffer convertFileToByteBuffer(String fileUrl) {
File file = new File(fileUrl);
try {
return ByteBuffer.wrap(FileUtils.readFileToByteArray(file));
} catch (Exception e) {
e.printStackTrace();
}
return null;
}
/**
* 创建STT WebSocket 客户端链接
* @param clientId 客户端ID
*/
private void createWhisperRealtimeSocket(String clientId){
try{
OkHttpClient client = new OkHttpClient();
// CountDownLatch latch = new CountDownLatch(1);
// 设置 WebSocket 请求
Request request = new Request.Builder()
.url(API_URL)
.addHeader("Authorization", "Bearer " + apiKey)
.addHeader("OpenAI-Beta", "realtime=v1")
.build();
client.newWebSocket(request, new WebSocketListener() {
@Override
public void onOpen(WebSocket webSocket, Response response) {
System.out.println("✅ WebSocket 连接成功");
//发送配置
JSONObject config = new JSONObject();
JSONObject sessionConfig = new JSONObject();
JSONObject transcription = new JSONObject();
JSONObject turnDetection = new JSONObject();
// 配置转录参数
transcription.put("model", MODEL);
transcription.put("language", language); // 中文
// 配置断句检测
turnDetection.put("type", "server_vad");
turnDetection.put("prefix_padding_ms", 300);
turnDetection.put("silence_duration_ms", 10);
// 组装完整配置
sessionConfig.put("input_audio_transcription", transcription);
sessionConfig.put("turn_detection", turnDetection);
config.put("type", "transcription_session.update");
config.put("session", sessionConfig);
webSocket.send(config.toString());
// 1. 启动音频缓冲
webSocket.send("{\"type\": \"input_audio_buffer.start\"}");
//存储客户端webSocket对象,对数据进行隔离处理
cacheWebSocket.put(clientId,webSocket);
}
@Override
public void onMessage(WebSocket webSocket, String text) {
System.out.println("📩 收到转录结果: " + text);
//对数据进行解析
if(StrUtil.isNotEmpty(text)){
Map<String,String> mapResultData = JSONUtil.toBean(text,Map.class);
if("conversation.item.input_audio_transcription.delta".equals(mapResultData.get("type"))){
String resultText = mapResultData.get("delta");
//进行客户端文本数据存储
String cacheString = cacheClientTts.get(clientId);
if(StrUtil.isNotEmpty(cacheString)){
cacheString = cacheString+resultText;
}else {
cacheString = resultText;
}
cacheClientTts.put(clientId,cacheString);
}
}
}
@Override
public void onFailure(WebSocket webSocket, Throwable t, Response response) {
System.err.println("❌ 连接失败: " + t.getMessage());
// latch.countDown();
}
@Override
public void onClosing(WebSocket webSocket, int code, String reason) {
System.out.println("⚠️ 连接即将关闭: " + reason);
webSocket.close(1000, null);
// latch.countDown();
}
});
// 等待 WebSocket 关闭
// latch.await();
}catch (Exception e){
e.printStackTrace();
}
}
}