Files
Vetti-Service-new/vetti-admin/src/main/java/com/vetti/socket/ChatWebSocketHandler.java

568 lines
26 KiB
Java
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package com.vetti.socket;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import com.vetti.common.ai.elevenLabs.ElevenLabsClient;
import com.vetti.common.ai.gpt.OpenAiStreamClient;
import com.vetti.common.ai.gpt.service.OpenAiStreamListenerService;
import com.vetti.common.ai.whisper.WhisperClient;
import com.vetti.common.config.RuoYiConfig;
import com.vetti.common.core.redis.RedisCache;
import com.vetti.common.utils.spring.SpringUtils;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import org.apache.commons.io.FileUtils;
import org.springframework.stereotype.Component;
import javax.sound.sampled.AudioFormat;
import javax.sound.sampled.AudioInputStream;
import javax.sound.sampled.AudioSystem;
import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.io.*;
import java.nio.ByteBuffer;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
/**
* 语音面试 web处理器
*/
@Slf4j
@ServerEndpoint("/voice-websocket/{clientId}")
@Component
public class ChatWebSocketHandler {
// @Value("${whisper.apiUrl}")
private String API_URL = "wss://api.openai.com/v1/realtime?intent=transcription";
// @Value("${whisper.model}")
private String MODEL = "gpt-4o-mini-transcribe";
// @Value("${whisper.apiKey}")
private String apiKey = "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";
// @Value("${whisper.language}")
private String language = "en";
/**
* 缓存客户端流式解析的语音文本数据
*/
private final Map<String, String> cacheClientTts = new ConcurrentHashMap<>();
/**
* 缓存客户端调用OpenAi中的websocket-STT 流式传输数据
*/
private final Map<String, WebSocket> cacheWebSocket = new ConcurrentHashMap<>();
/**
* 缓存客户端,标记是否是自我介绍后的初次问答
*/
private final Map<String,String> cacheReplyFlag = new ConcurrentHashMap<>();
/**
* 缓存客户端,面试回答信息
*/
private final Map<String,String> cacheMsgMapData = new ConcurrentHashMap<>();
/**
* 缓存客户端,AI提问的问题结果信息
*/
private final Map<String,String> cacheQuestionResult = new ConcurrentHashMap<>();
// 语音文件保存目录
private static final String VOICE_STORAGE_DIR = "/voice_files/";
// 语音结果文件保存目录
private static final String VOICE_STORAGE_RESULT_DIR = "/voice_result_files/";
// 系统语音目录
private static final String VOICE_SYSTEM_DIR = "/system_files/";
public ChatWebSocketHandler() {
// 初始化存储目录
File dir = new File(RuoYiConfig.getProfile() + VOICE_STORAGE_DIR);
if (!dir.exists()) {
dir.mkdirs();
}
File resultDir = new File(RuoYiConfig.getProfile() + VOICE_STORAGE_RESULT_DIR);
if (!resultDir.exists()) {
resultDir.mkdirs();
}
}
// 连接建立时调用
@OnOpen
public void onOpen(Session session, @PathParam("clientId") String clientId) {
log.info("WebSocket 链接已建立:{}", clientId);
log.info("WebSocket session 链接已建立:{}", session.getId());
cacheClientTts.put(clientId, new String());
//初始化STT流式语音转换文本的socket链接
createWhisperRealtimeSocket(session.getId());
//是初次自我介绍后的问答环节
cacheReplyFlag.put(session.getId(),"YES");
//初始化面试回答数据记录
cacheMsgMapData.put(session.getId(),"");
//初始化面试问题
cacheQuestionResult.put(session.getId(),"");
//发送初始化面试官语音流
String openingPathUrl = RuoYiConfig.getProfile() + VOICE_SYSTEM_DIR + "opening.wav";
try {
//文件转换成文件流
ByteBuffer outByteBuffer = convertFileToByteBuffer(openingPathUrl);
//发送文件流数据
session.getBasicRemote().sendBinary(outByteBuffer);
// 发送响应确认
log.info("初始化返回面试官语音信息:{}", System.currentTimeMillis() / 1000);
} catch (IOException e) {
e.printStackTrace();
}
}
// 接收文本消息
@OnMessage
public void onTextMessage(Session session, String message, @PathParam("clientId") String clientId) {
System.out.println("接收到文本消息: " + message);
try {
//处理文本结果
if (StrUtil.isNotEmpty(message)) {
Map<String, String> mapResult = JSONUtil.toBean(JSONUtil.parseObj(message), Map.class);
String resultFlag = mapResult.get("msg");
if ("done".equals(resultFlag)) {
//开始合并语音流
//发送消息
WebSocket webSocket = cacheWebSocket.get(session.getId());
if (webSocket != null) {
webSocket.send("{\"type\": \"input_audio_buffer.commit\"}");
webSocket.send("{\"type\": \"response.create\"}");
}
String startFlag = cacheReplyFlag.get(session.getId());
//语音结束,开始进行回答解析
String cacheResultText = cacheClientTts.get(clientId);
log.info("面试者回答信息为:{}", cacheResultText);
if (StrUtil.isEmpty(cacheResultText)) {
cacheResultText = "Hi.";
}
String promptJson = "";
if("YES".equals(startFlag)) {
//自我介绍结束后马上返回一个Good
//发送初始化面试官语音流
String openingPathUrl = RuoYiConfig.getProfile() + VOICE_SYSTEM_DIR + "good.wav";
try {
//文件转换成文件流
ByteBuffer outByteBuffer = convertFileToByteBuffer(openingPathUrl);
//发送文件流数据
session.getBasicRemote().sendBinary(outByteBuffer);
// 发送响应确认
log.info("初始化返回面试官语音信息:{}", System.currentTimeMillis() / 1000);
} catch (IOException e) {
e.printStackTrace();
}
Map<String,String> mapEntity = new HashMap<>();
mapEntity.put("role","system");
mapEntity.put("content","你是面试官根据Construction Labourer候选人回答生成追问。只要一个问题,问题不要重复");
List<Map<String,String>> list = new LinkedList();
list.add(mapEntity);
promptJson = JSONUtil.toJsonStr(list);
//记录缓存中
cacheMsgMapData.put(session.getId(),promptJson);
}else{
//开始根据面试者回答的问题,进行追问回答
// {
// role: "system",
// content: "你是面试官根据Construction Labourer候选人回答生成追问。"
// },
// {
// role: "user",
// content: `问题:${question}\n候选人回答${answer}`
// }
//获取面试者回答信息
//获取缓存记录
String msgMapData = cacheMsgMapData.get(session.getId());
if(StrUtil.isNotEmpty(msgMapData)){
List<Map> list = JSONUtil.toList(msgMapData, Map.class);
//获取最后一条数据记录
Map<String,String> mapEntity = list.get(list.size()-1);
//更新问题记录
String content = mapEntity.get("content");
mapEntity.put("content", StrUtil.format(content, cacheResultText));
promptJson = JSONUtil.toJsonStr(list);
cacheMsgMapData.put(session.getId(),promptJson);
}
}
//获取完问答数据,直接清空缓存数据
cacheClientTts.put(clientId,"");
cacheReplyFlag.put(session.getId(),"");
//把提问的文字发送给CPT(流式处理)
OpenAiStreamClient aiStreamClient = SpringUtils.getBean(OpenAiStreamClient.class);
log.info("AI提示词为:{}",promptJson);
aiStreamClient.streamChat(promptJson, new OpenAiStreamListenerService() {
@Override
public void onMessage(String content) {
log.info("返回AI结果{}", content);
String questionResult = cacheQuestionResult.get(session.getId());
if(StrUtil.isEmpty(questionResult)){
questionResult = content;
}else{
questionResult = questionResult + content;
}
cacheQuestionResult.put(session.getId(),questionResult);
// 实时输出内容
//开始进行语音输出-流式持续输出
//把结果文字转成语音文件
//生成文件
//生成唯一文件名
String resultFileName = clientId + "_" + System.currentTimeMillis() + ".wav";
String resultPathUrl = RuoYiConfig.getProfile() + VOICE_STORAGE_RESULT_DIR + resultFileName;
ElevenLabsClient elevenLabsClient = SpringUtils.getBean(ElevenLabsClient.class);
elevenLabsClient.handleTextToVoice(content, resultPathUrl);
//持续返回数据流给客户端
try {
//文件转换成文件流
ByteBuffer outByteBuffer = convertFileToByteBuffer(resultPathUrl);
//发送文件流数据
session.getBasicRemote().sendBinary(outByteBuffer);
// 发送响应确认
} catch (IOException e) {
e.printStackTrace();
}
}
@Override
public void onComplete() {
try {
//开始往缓存中记录提问的问题
String questionResult = cacheQuestionResult.get(session.getId());
//获取缓存记录
String msgMapData = cacheMsgMapData.get(session.getId());
if(StrUtil.isNotEmpty(msgMapData)){
List<Map> list = JSONUtil.toList(msgMapData, Map.class);
Map<String,String> mapEntity = new HashMap<>();
mapEntity.put("role","user");
mapEntity.put("content","问题:"+questionResult+"\\n候选人回答{}");
list.add(mapEntity);
cacheMsgMapData.put(session.getId(),JSONUtil.toJsonStr(list));
}
//清空问题
cacheQuestionResult.put(session.getId(),"");
Map<String, String> resultEntity = new HashMap<>();
resultEntity.put("msg", "done");
//发送通知告诉客户端已经回答结束了
session.getBasicRemote().sendText(JSONUtil.toJsonStr(resultEntity));
} catch (Exception e) {
throw new RuntimeException(e);
}
}
@Override
public void onError(Throwable throwable) {
throwable.printStackTrace();
}
});
}else if("end".equals(resultFlag)){
//发送面试官结束语音流
String openingPathUrl = RuoYiConfig.getProfile() + VOICE_SYSTEM_DIR + "end.wav";
try {
//文件转换成文件流
ByteBuffer outByteBuffer = convertFileToByteBuffer(openingPathUrl);
//发送文件流数据
session.getBasicRemote().sendBinary(outByteBuffer);
// 发送响应确认
log.info("结束返回面试官语音信息:{}", System.currentTimeMillis() / 1000);
} catch (IOException e) {
e.printStackTrace();
}
//返回文本评分
//处理模型提问逻辑
String promptJson = "";
//获取缓存记录
String msgMapData = cacheMsgMapData.get(session.getId());
if(StrUtil.isNotEmpty(msgMapData)){
List<Map> list = JSONUtil.toList(msgMapData, Map.class);
//获取最后一条数据记录
Map<String,String> mapEntity = list.get(0);
//更新问题记录
mapEntity.put("role","system");
mapEntity.put("content","你是建筑行业面试专家对Construction Labourer候选人回答进行1-5分评分。");
promptJson = JSONUtil.toJsonStr(list);
//结束回答要清空问答数据
cacheMsgMapData.put(session.getId(),"");
}
log.info("结束AI提示词为:{}",promptJson);
OpenAiStreamClient aiStreamClient = SpringUtils.getBean(OpenAiStreamClient.class);
aiStreamClient.streamChat(promptJson, new OpenAiStreamListenerService() {
@Override
public void onMessage(String content) {
log.info("返回AI结果{}", content);
try {
//发送文件流数据
session.getBasicRemote().sendText(content);
}catch (Exception e){
e.printStackTrace();
}
}
@Override
public void onComplete() {
}
@Override
public void onError(Throwable throwable) {
throwable.printStackTrace();
}
});
}
}
} catch (Exception e) {
e.printStackTrace();
}
}
// 接收二进制消息(流数据)
@OnMessage
public void onBinaryMessage(Session session, @PathParam("clientId") String clientId, ByteBuffer byteBuffer) {
log.info("客户端ID为:{}", clientId);
// 处理二进制流数据
byte[] bytes = new byte[byteBuffer.remaining()];
//从缓冲区中读取数据并存储到指定的字节数组中
byteBuffer.get(bytes);
// 生成唯一文件名
String fileName = clientId + "_" + System.currentTimeMillis() + ".wav";
String pathUrl = RuoYiConfig.getProfile() + VOICE_STORAGE_DIR + fileName;
log.info("文件路径为:{}", pathUrl);
try {
saveAsWebM(bytes, pathUrl);
//接收到数据流后直接就进行SST处理
//语音格式转换
String fileOutName = clientId + "_" + System.currentTimeMillis() + ".pcm";
String pathOutUrl = RuoYiConfig.getProfile() + VOICE_STORAGE_DIR + fileOutName;
handleAudioToPCM(pathUrl, pathOutUrl);
//发送消息
WebSocket webSocket = cacheWebSocket.get(session.getId());
log.info("获取的socket对象为:{}", webSocket);
if (webSocket != null) {
// 1. 启动音频缓冲
// webSocket.send("{\"type\": \"input_audio_buffer.start\"}");
File outputFile = new File(pathOutUrl); // 输出PCM格式文件
ByteBuffer buffer = ByteBuffer.wrap(FileUtils.readFileToByteArray(outputFile));
byte[] outBytes = new byte[buffer.remaining()];
//从缓冲区中读取数据并存储到指定的字节数组中
buffer.get(outBytes);
String base64Audio = Base64.getEncoder().encodeToString(outBytes);
String message = "{ \"type\": \"input_audio_buffer.append\", \"audio\": \"" + base64Audio + "\" }";
webSocket.send(message);
// 3. 提交音频并请求转录
// webSocket.send("{\"type\": \"input_audio_buffer.commit\"}");
// webSocket.send("{\"type\": \"response.create\"}");
}
} catch (Exception e) {
e.printStackTrace();
}
}
// 连接关闭时调用
@OnClose
public void onClose(Session session, CloseReason reason) {
System.out.println("WebSocket连接已关闭: " + session.getId() + ", 原因: " + reason.getReasonPhrase());
WebSocket webSocket = cacheWebSocket.get(session.getId());
if (webSocket != null) {
webSocket.close(1000, null);
}
}
// 发生错误时调用
@OnError
public void onError(Session session, Throwable throwable) {
System.err.println("WebSocket错误发生: " + throwable.getMessage());
throwable.printStackTrace();
}
/**
* 将字节数组保存为WebM文件
*
* @param byteData 包含WebM数据的字节数组
* @param filePath 目标文件路径
* @return 操作是否成功
*/
private boolean saveAsWebM(byte[] byteData, String filePath) {
// 检查输入参数
if (byteData == null || byteData.length == 0) {
System.err.println("字节数组为空无法生成WebM文件");
return false;
}
if (filePath == null || filePath.trim().isEmpty()) {
System.err.println("文件路径不能为空");
return false;
}
FileOutputStream fos = null;
try {
fos = new FileOutputStream(filePath);
fos.write(byteData);
fos.flush();
System.out.println("WebM文件已成功生成: " + filePath);
return true;
} catch (IOException e) {
System.err.println("写入文件时发生错误: " + e.getMessage());
e.printStackTrace();
} finally {
if (fos != null) {
try {
fos.close();
} catch (IOException e) {
e.printStackTrace();
}
}
}
return false;
}
/**
* File 转换成 ByteBuffer
*
* @param fileUrl 文件路径
* @return
*/
private ByteBuffer convertFileToByteBuffer(String fileUrl) {
File file = new File(fileUrl);
try {
return ByteBuffer.wrap(FileUtils.readFileToByteArray(file));
} catch (Exception e) {
e.printStackTrace();
}
return null;
}
/**
* 创建STT WebSocket 客户端链接
*
* @param clientId 客户端ID
*/
private void createWhisperRealtimeSocket(String clientId) {
try {
OkHttpClient client = new OkHttpClient();
// 设置 WebSocket 请求
Request request = new Request.Builder()
.url(API_URL)
.addHeader("Authorization", "Bearer " + apiKey)
.addHeader("OpenAI-Beta", "realtime=v1")
.build();
client.newWebSocket(request, new WebSocketListener() {
@Override
public void onOpen(WebSocket webSocket, Response response) {
System.out.println("✅ WebSocket 连接成功");
//发送配置
JSONObject config = new JSONObject();
JSONObject sessionConfig = new JSONObject();
JSONObject transcription = new JSONObject();
JSONObject turnDetection = new JSONObject();
// 配置转录参数
transcription.put("model", MODEL);
transcription.put("language", language); // 中文
// 配置断句检测
turnDetection.put("type", "server_vad");
turnDetection.put("prefix_padding_ms", 300);
turnDetection.put("silence_duration_ms", 10);
// 组装完整配置
sessionConfig.put("input_audio_transcription", transcription);
sessionConfig.put("turn_detection", turnDetection);
config.put("type", "transcription_session.update");
config.put("session", sessionConfig);
webSocket.send(config.toString());
// 1. 启动音频缓冲
// webSocket.send("{\"type\": \"input_audio_buffer.start\"}");
//存储客户端webSocket对象,对数据进行隔离处理
cacheWebSocket.put(clientId, webSocket);
}
@Override
public void onMessage(WebSocket webSocket, String text) {
System.out.println("📩 收到转录结果: " + text);
//对数据进行解析
if (StrUtil.isNotEmpty(text)) {
Map<String, String> mapResultData = JSONUtil.toBean(text, Map.class);
if ("conversation.item.input_audio_transcription.delta".equals(mapResultData.get("type"))) {
String resultText = mapResultData.get("delta");
//进行客户端文本数据存储
String cacheString = cacheClientTts.get(clientId);
if (StrUtil.isNotEmpty(cacheString)) {
cacheString = cacheString + resultText;
} else {
cacheString = resultText;
}
cacheClientTts.put(clientId, cacheString);
}
}
}
@Override
public void onFailure(WebSocket webSocket, Throwable t, Response response) {
System.err.println("❌ 连接失败: " + t.getMessage());
// latch.countDown();
}
@Override
public void onClosing(WebSocket webSocket, int code, String reason) {
System.out.println("⚠️ 连接即将关闭: " + reason);
webSocket.close(1000, null);
// latch.countDown();
}
});
} catch (Exception e) {
e.printStackTrace();
}
}
/**
* 语音流文件格式转换
*
* @param pathUrl
* @param outPathUrl
*/
private void handleAudioToPCM(String pathUrl, String outPathUrl) {
File inputFile = new File(pathUrl); // 输入音频文件
File outputFile = new File(outPathUrl); // 输出PCM格式文件
try {
// 读取音频文件
AudioInputStream inputAudioStream = AudioSystem.getAudioInputStream(inputFile);
// 获取音频文件的格式信息
AudioFormat sourceFormat = inputAudioStream.getFormat();
System.out.println("Input Audio Format: " + sourceFormat);
// 设置目标PCM格式 (可以是16-bit, 8kHz, Mono, Linear PCM)
AudioFormat pcmFormat = new AudioFormat(
AudioFormat.Encoding.PCM_SIGNED,
sourceFormat.getSampleRate(),
16, // 16-bit samples
1, // 单声道
2, // 每个样本2字节16位
sourceFormat.getSampleRate(),
false // 大端模式
);
// 获取PCM格式的音频流
AudioInputStream pcmAudioStream = AudioSystem.getAudioInputStream(pcmFormat, inputAudioStream);
// 创建输出文件流
FileOutputStream fos = new FileOutputStream(outputFile);
byte[] buffer = new byte[1024];
int bytesRead;
// 将PCM音频数据写入输出文件
while ((bytesRead = pcmAudioStream.read(buffer)) != -1) {
fos.write(buffer, 0, bytesRead);
}
// 关闭流
pcmAudioStream.close();
fos.close();
System.out.println("Audio has been converted to PCM format and saved at: " + outputFile.getAbsolutePath());
} catch (Exception e) {
e.printStackTrace();
}
}
}