diff --git a/vetti-admin/src/main/java/com/vetti/socket/FileReceiverConfig.java b/vetti-admin/src/main/java/com/vetti/socket/FileReceiverConfig.java index 30a7d95..a56220c 100644 --- a/vetti-admin/src/main/java/com/vetti/socket/FileReceiverConfig.java +++ b/vetti-admin/src/main/java/com/vetti/socket/FileReceiverConfig.java @@ -21,7 +21,7 @@ public class FileReceiverConfig implements WebSocketConfigurer { @Override public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { - registry.addHandler(fileReceiverWebSocketHandler(), "/voice-websocket") + registry.addHandler(fileReceiverWebSocketHandler(), "/voice-websocket111111") .addInterceptors(handshakeInterceptor()) .setAllowedOrigins("*"); // 生产环境需替换为具体域名 } diff --git a/vetti-admin/src/main/java/com/vetti/socket/VoiceWebSocketHandler.java b/vetti-admin/src/main/java/com/vetti/socket/VoiceWebSocketHandler.java index b3157b7..58a982c 100644 --- a/vetti-admin/src/main/java/com/vetti/socket/VoiceWebSocketHandler.java +++ b/vetti-admin/src/main/java/com/vetti/socket/VoiceWebSocketHandler.java @@ -1,143 +1,170 @@ package com.vetti.socket; -import com.vetti.socket.vo.FileMetadata; +import cn.hutool.json.JSONUtil; import com.vetti.socket.vo.FileTransferState; +import com.vetti.socket.vo.VoicePartMessage; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Component; -import org.springframework.web.socket.*; +import org.springframework.web.socket.BinaryMessage; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.handler.TextWebSocketHandler; import com.fasterxml.jackson.databind.ObjectMapper; import java.io.*; import java.nio.ByteBuffer; -import java.nio.file.*; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; import java.util.*; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.locks.ReentrantLock; +import java.util.Base64; @Slf4j @Component public class VoiceWebSocketHandler extends TextWebSocketHandler { - // 存储客户端的文件传输状态:clientId -> FileTransferState - private final Map transferStates = new ConcurrentHashMap<>(); - private final ObjectMapper objectMapper = new ObjectMapper(); - private static final String STORAGE_DIR = "received_files/"; - // 初始化存储目录 - static { - try { - Files.createDirectories(Paths.get(STORAGE_DIR)); - } catch (IOException e) { - throw new RuntimeException("无法创建文件存储目录", e); + // 存储每个客户端的语音分片,key: clientId, value: 分片映射 + private final Map> clientVoiceParts = new ConcurrentHashMap<>(); + // 存储每个客户端的总分片数,key: clientId + private final Map clientTotalParts = new ConcurrentHashMap<>(); + // 用于并发控制的锁 + private final Map clientLocks = new ConcurrentHashMap<>(); + // JSON序列化工具 + private final ObjectMapper objectMapper = new ObjectMapper(); + // 语音文件保存目录 + private static final String VOICE_STORAGE_DIR = "voice_files/"; + + public VoiceWebSocketHandler() { + // 初始化存储目录 + File dir = new File(VOICE_STORAGE_DIR); + if (!dir.exists()) { + dir.mkdirs(); } } @Override public void afterConnectionEstablished(WebSocketSession session) throws Exception { - String clientId = (String) session.getAttributes().get("clientId"); - transferStates.put(clientId, new FileTransferState()); - System.out.println("客户端连接: " + clientId); + String clientId = getClientId(session); + if (clientId != null) { + // 初始化客户端数据结构 + clientVoiceParts.put(clientId, new TreeMap<>()); // TreeMap保证分片有序 + clientLocks.putIfAbsent(clientId, new ReentrantLock()); + System.out.println("客户端连接建立: " + clientId); + } } - // 处理文本消息(文件元数据) @Override protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception { - String clientId = (String) session.getAttributes().get("clientId"); - FileMetadata metadata = objectMapper.readValue(message.getPayload(), FileMetadata.class); + log.info("开始进入文本传输里面了"); + String clientId = getClientId(session); + if (clientId == null) { + System.err.println("无法获取客户端ID"); + return; + } - // 初始化文件传输状态 - FileTransferState state = transferStates.get(clientId); - state.setFileName(metadata.getFileName()); - state.setTotalSize(metadata.getTotalSize()); - state.setTotalParts(metadata.getTotalParts()); - state.setOutputStream(new FileOutputStream(STORAGE_DIR + metadata.getFileName())); + try { + // 解析前端发送的JSON消息 + VoicePartMessage voiceMessage = objectMapper.readValue(message.getPayload(), VoicePartMessage.class); - System.out.println("开始接收文件: " + metadata.getFileName() + " (" + metadata.getTotalParts() + "个分片)"); - - // 确认已收到元数据 - session.sendMessage(new TextMessage("{\"type\":\"metadata_ack\"}")); + // 处理语音分片 + if ("voice_part".equals(voiceMessage.getType())) { + processVoicePart(clientId, voiceMessage, session); + } + } catch (Exception e) { + System.err.println("处理消息出错: " + e.getMessage()); + e.printStackTrace(); + } } // 处理二进制消息(文件分片) @Override protected void handleBinaryMessage(WebSocketSession session, BinaryMessage message){ - try{ - log.info("开始-接收文件分片数据流"); - String clientId = (String) session.getAttributes().get("clientId"); - FileTransferState state = transferStates.get(clientId); - - if (state == null || state.getOutputStream() == null) { - session.sendMessage(new TextMessage("{\"type\":\"error\", \"message\":\"未收到文件元数据\"}")); - return; - } - log.info("进行中-接收文件分片数据流"); - // 解析分片数据 - ByteBuffer payload = message.getPayload(); - // int partNumber = payload.getInt(); // 前4字节是分片编号 - byte[] data = new byte[payload.remaining()]; - payload.get(data); - - // 写入文件 - state.getOutputStream().write(data); - state.incrementReceivedParts(); - - // 发送进度更新(每5个分片或最后一个分片) - if (state.getReceivedParts() % 5 == 0 || state.getReceivedParts() == state.getTotalParts()) { - // 检查是否接收完成 - if (state.getReceivedParts() == state.getTotalParts()) { - log.info("生成完整的文件-接收文件分片数据流"); - completeFileTransfer(session, state, clientId); - //进行文件数据转换 - - //获取最终的文件结果 - - //把文件转成对应的文件流,返回给前端 -// session.sendMessage(new BinaryMessage()); - } - } - }catch (Exception e){ - e.printStackTrace(); - } - } - - // 完成文件传输 - private void completeFileTransfer(WebSocketSession session, FileTransferState state, String clientId) throws IOException { - // 关闭文件输出流 - state.getOutputStream().close(); - // 验证文件大小 - File file = new File(STORAGE_DIR + state.getFileName()); - boolean fileValid = file.length() == state.getTotalSize(); - // 发送完成消息 - String result = fileValid ? - "{\"type\":\"complete\", \"message\":\"文件接收完成\", \"filePath\":\"" + file.getAbsolutePath() + "\"}" : - "{\"type\":\"error\", \"message\":\"文件损坏,大小不匹配\"}"; - session.sendMessage(new TextMessage(result)); - - System.out.println("文件接收" + (fileValid ? "完成" : "失败") + ": " + state.getFileName()); - - // 清理状态 - transferStates.remove(clientId); + log.info("开始进入文件流传输里面了"); + log.info("获取的数据为:{}", JSONUtil.toJsonStr(message)); } @Override public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception { - String clientId = (String) session.getAttributes().get("clientId"); - FileTransferState state = transferStates.remove(clientId); - - // 关闭可能存在的文件流 - if (state != null && state.getOutputStream() != null) { - try { - state.getOutputStream().close(); - // 删除未完成的文件 - Files.deleteIfExists(Paths.get(STORAGE_DIR + state.getFileName())); - } catch (IOException e) { - e.printStackTrace(); - } + String clientId = getClientId(session); + if (clientId != null) { + // 清理客户端资源 + clientVoiceParts.remove(clientId); + clientTotalParts.remove(clientId); + clientLocks.remove(clientId); + System.out.println("客户端连接关闭: " + clientId); } - System.out.println("客户端断开连接: " + clientId); } - @Override - public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception { - System.err.println("传输错误: " + exception.getMessage()); - session.close(CloseStatus.SERVER_ERROR); + /** + * 处理语音分片 + */ + private void processVoicePart(String clientId, VoicePartMessage message, WebSocketSession session) throws Exception { + ReentrantLock lock = clientLocks.get(clientId); + lock.lock(); // 加锁确保线程安全 + try { + // 保存总分片数 + clientTotalParts.put(clientId, message.getTotalParts()); + + // 解码Base64数据并存储分片 + byte[] voiceData = Base64.getDecoder().decode(message.getData()); + clientVoiceParts.get(clientId).put(message.getPartNumber(), voiceData); + + System.out.printf("接收客户端 %s 的分片 %d/%d%n", + clientId, message.getPartNumber() + 1, message.getTotalParts()); + + // 检查是否所有分片都已接收 + checkAndMergeParts(clientId, session); + } finally { + lock.unlock(); // 释放锁 + } + } + + /** + * 检查是否所有分片都已接收,如果是则合并 + */ + private void checkAndMergeParts(String clientId, WebSocketSession session) throws Exception { + Map parts = clientVoiceParts.get(clientId); + Integer totalParts = clientTotalParts.get(clientId); + + if (parts == null || totalParts == null) { + return; + } + + // 所有分片都已接收 + if (parts.size() == totalParts) { + System.out.println("所有分片接收完成,开始合并: " + clientId); + + // 生成唯一文件名 + String fileName = clientId + "_" + System.currentTimeMillis() + ".wav"; + Path outputPath = Paths.get(VOICE_STORAGE_DIR + fileName); + + // 合并分片 + try (FileOutputStream fos = new FileOutputStream(outputPath.toFile())) { + for (byte[] part : parts.values()) { + fos.write(part); + } + } + + System.out.println("语音文件合并完成,保存路径: " + outputPath); + + // 向客户端发送处理完成消息 + Map response = new HashMap<>(); + response.put("type", "complete"); + response.put("message", "语音接收完成"); + response.put("fileName", fileName); + session.sendMessage(new TextMessage(objectMapper.writeValueAsString(response))); + + // 清理已合并的分片数据 + clientVoiceParts.get(clientId).clear(); + } + } + + /** + * 从会话中获取客户端ID + */ + private String getClientId(WebSocketSession session) { + return (String) session.getAttributes().get("clientId"); } } diff --git a/vetti-admin/src/main/java/com/vetti/socket/WebSocketConfig.java b/vetti-admin/src/main/java/com/vetti/socket/WebSocketConfig.java index 2657e60..0b52f1c 100644 --- a/vetti-admin/src/main/java/com/vetti/socket/WebSocketConfig.java +++ b/vetti-admin/src/main/java/com/vetti/socket/WebSocketConfig.java @@ -24,7 +24,7 @@ public class WebSocketConfig implements WebSocketConfigurer { @Override public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { // 注册WebSocket处理器,设置路径和允许跨域 - registry.addHandler(voiceWebSocketHandler, "/voice-websocket123") + registry.addHandler(voiceWebSocketHandler, "/voice-websocket") .addInterceptors(voiceHandshakeInterceptor) .setAllowedOrigins("*"); // 生产环境应指定具体域名而非* }