websocket 生成完整的文件

This commit is contained in:
wangxiangshun
2025-10-05 15:03:38 +08:00
parent eeeb176098
commit 67c22e3082
3 changed files with 131 additions and 104 deletions

View File

@@ -21,7 +21,7 @@ public class FileReceiverConfig implements WebSocketConfigurer {
@Override @Override
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
registry.addHandler(fileReceiverWebSocketHandler(), "/voice-websocket") registry.addHandler(fileReceiverWebSocketHandler(), "/voice-websocket111111")
.addInterceptors(handshakeInterceptor()) .addInterceptors(handshakeInterceptor())
.setAllowedOrigins("*"); // 生产环境需替换为具体域名 .setAllowedOrigins("*"); // 生产环境需替换为具体域名
} }

View File

@@ -1,143 +1,170 @@
package com.vetti.socket; package com.vetti.socket;
import com.vetti.socket.vo.FileMetadata; import cn.hutool.json.JSONUtil;
import com.vetti.socket.vo.FileTransferState; import com.vetti.socket.vo.FileTransferState;
import com.vetti.socket.vo.VoicePartMessage;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.web.socket.*; import org.springframework.web.socket.BinaryMessage;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler; import org.springframework.web.socket.handler.TextWebSocketHandler;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.*; import java.io.*;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.file.*; import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*; import java.util.*;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.ReentrantLock;
import java.util.Base64;
@Slf4j @Slf4j
@Component @Component
public class VoiceWebSocketHandler extends TextWebSocketHandler { public class VoiceWebSocketHandler extends TextWebSocketHandler {
// 存储客户端的文件传输状态clientId -> FileTransferState
private final Map<String, FileTransferState> transferStates = new ConcurrentHashMap<>();
private final ObjectMapper objectMapper = new ObjectMapper();
private static final String STORAGE_DIR = "received_files/";
// 存储每个客户端的语音分片key: clientId, value: 分片映射
private final Map<String, Map<Integer, byte[]>> clientVoiceParts = new ConcurrentHashMap<>();
// 存储每个客户端的总分片数key: clientId
private final Map<String, Integer> clientTotalParts = new ConcurrentHashMap<>();
// 用于并发控制的锁
private final Map<String, ReentrantLock> clientLocks = new ConcurrentHashMap<>();
// JSON序列化工具
private final ObjectMapper objectMapper = new ObjectMapper();
// 语音文件保存目录
private static final String VOICE_STORAGE_DIR = "voice_files/";
public VoiceWebSocketHandler() {
// 初始化存储目录 // 初始化存储目录
static { File dir = new File(VOICE_STORAGE_DIR);
try { if (!dir.exists()) {
Files.createDirectories(Paths.get(STORAGE_DIR)); dir.mkdirs();
} catch (IOException e) {
throw new RuntimeException("无法创建文件存储目录", e);
} }
} }
@Override @Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception { public void afterConnectionEstablished(WebSocketSession session) throws Exception {
String clientId = (String) session.getAttributes().get("clientId"); String clientId = getClientId(session);
transferStates.put(clientId, new FileTransferState()); if (clientId != null) {
System.out.println("客户端连接: " + clientId); // 初始化客户端数据结构
clientVoiceParts.put(clientId, new TreeMap<>()); // TreeMap保证分片有序
clientLocks.putIfAbsent(clientId, new ReentrantLock());
System.out.println("客户端连接建立: " + clientId);
}
} }
// 处理文本消息(文件元数据)
@Override @Override
protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception { protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
String clientId = (String) session.getAttributes().get("clientId"); log.info("开始进入文本传输里面了");
FileMetadata metadata = objectMapper.readValue(message.getPayload(), FileMetadata.class); String clientId = getClientId(session);
if (clientId == null) {
System.err.println("无法获取客户端ID");
return;
}
// 初始化文件传输状态 try {
FileTransferState state = transferStates.get(clientId); // 解析前端发送的JSON消息
state.setFileName(metadata.getFileName()); VoicePartMessage voiceMessage = objectMapper.readValue(message.getPayload(), VoicePartMessage.class);
state.setTotalSize(metadata.getTotalSize());
state.setTotalParts(metadata.getTotalParts());
state.setOutputStream(new FileOutputStream(STORAGE_DIR + metadata.getFileName()));
System.out.println("开始接收文件: " + metadata.getFileName() + " (" + metadata.getTotalParts() + "个分片)"); // 处理语音分片
if ("voice_part".equals(voiceMessage.getType())) {
// 确认已收到元数据 processVoicePart(clientId, voiceMessage, session);
session.sendMessage(new TextMessage("{\"type\":\"metadata_ack\"}")); }
} catch (Exception e) {
System.err.println("处理消息出错: " + e.getMessage());
e.printStackTrace();
}
} }
// 处理二进制消息(文件分片) // 处理二进制消息(文件分片)
@Override @Override
protected void handleBinaryMessage(WebSocketSession session, BinaryMessage message){ protected void handleBinaryMessage(WebSocketSession session, BinaryMessage message){
try{ log.info("开始进入文件流传输里面了");
log.info("开始-接收文件分片数据流"); log.info("获取的数据为:{}", JSONUtil.toJsonStr(message));
String clientId = (String) session.getAttributes().get("clientId");
FileTransferState state = transferStates.get(clientId);
if (state == null || state.getOutputStream() == null) {
session.sendMessage(new TextMessage("{\"type\":\"error\", \"message\":\"未收到文件元数据\"}"));
return;
}
log.info("进行中-接收文件分片数据流");
// 解析分片数据
ByteBuffer payload = message.getPayload();
// int partNumber = payload.getInt(); // 前4字节是分片编号
byte[] data = new byte[payload.remaining()];
payload.get(data);
// 写入文件
state.getOutputStream().write(data);
state.incrementReceivedParts();
// 发送进度更新每5个分片或最后一个分片
if (state.getReceivedParts() % 5 == 0 || state.getReceivedParts() == state.getTotalParts()) {
// 检查是否接收完成
if (state.getReceivedParts() == state.getTotalParts()) {
log.info("生成完整的文件-接收文件分片数据流");
completeFileTransfer(session, state, clientId);
//进行文件数据转换
//获取最终的文件结果
//把文件转成对应的文件流,返回给前端
// session.sendMessage(new BinaryMessage());
}
}
}catch (Exception e){
e.printStackTrace();
}
}
// 完成文件传输
private void completeFileTransfer(WebSocketSession session, FileTransferState state, String clientId) throws IOException {
// 关闭文件输出流
state.getOutputStream().close();
// 验证文件大小
File file = new File(STORAGE_DIR + state.getFileName());
boolean fileValid = file.length() == state.getTotalSize();
// 发送完成消息
String result = fileValid ?
"{\"type\":\"complete\", \"message\":\"文件接收完成\", \"filePath\":\"" + file.getAbsolutePath() + "\"}" :
"{\"type\":\"error\", \"message\":\"文件损坏,大小不匹配\"}";
session.sendMessage(new TextMessage(result));
System.out.println("文件接收" + (fileValid ? "完成" : "失败") + ": " + state.getFileName());
// 清理状态
transferStates.remove(clientId);
} }
@Override @Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception { public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
String clientId = (String) session.getAttributes().get("clientId"); String clientId = getClientId(session);
FileTransferState state = transferStates.remove(clientId); if (clientId != null) {
// 清理客户端资源
clientVoiceParts.remove(clientId);
clientTotalParts.remove(clientId);
clientLocks.remove(clientId);
System.out.println("客户端连接关闭: " + clientId);
}
}
// 关闭可能存在的文件流 /**
if (state != null && state.getOutputStream() != null) { * 处理语音分片
*/
private void processVoicePart(String clientId, VoicePartMessage message, WebSocketSession session) throws Exception {
ReentrantLock lock = clientLocks.get(clientId);
lock.lock(); // 加锁确保线程安全
try { try {
state.getOutputStream().close(); // 保存总分片数
// 删除未完成的文件 clientTotalParts.put(clientId, message.getTotalParts());
Files.deleteIfExists(Paths.get(STORAGE_DIR + state.getFileName()));
} catch (IOException e) { // 解码Base64数据并存储分片
e.printStackTrace(); byte[] voiceData = Base64.getDecoder().decode(message.getData());
clientVoiceParts.get(clientId).put(message.getPartNumber(), voiceData);
System.out.printf("接收客户端 %s 的分片 %d/%d%n",
clientId, message.getPartNumber() + 1, message.getTotalParts());
// 检查是否所有分片都已接收
checkAndMergeParts(clientId, session);
} finally {
lock.unlock(); // 释放锁
} }
} }
System.out.println("客户端断开连接: " + clientId);
}
@Override /**
public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception { * 检查是否所有分片都已接收,如果是则合并
System.err.println("传输错误: " + exception.getMessage()); */
session.close(CloseStatus.SERVER_ERROR); private void checkAndMergeParts(String clientId, WebSocketSession session) throws Exception {
Map<Integer, byte[]> parts = clientVoiceParts.get(clientId);
Integer totalParts = clientTotalParts.get(clientId);
if (parts == null || totalParts == null) {
return;
}
// 所有分片都已接收
if (parts.size() == totalParts) {
System.out.println("所有分片接收完成,开始合并: " + clientId);
// 生成唯一文件名
String fileName = clientId + "_" + System.currentTimeMillis() + ".wav";
Path outputPath = Paths.get(VOICE_STORAGE_DIR + fileName);
// 合并分片
try (FileOutputStream fos = new FileOutputStream(outputPath.toFile())) {
for (byte[] part : parts.values()) {
fos.write(part);
}
}
System.out.println("语音文件合并完成,保存路径: " + outputPath);
// 向客户端发送处理完成消息
Map<String, Object> response = new HashMap<>();
response.put("type", "complete");
response.put("message", "语音接收完成");
response.put("fileName", fileName);
session.sendMessage(new TextMessage(objectMapper.writeValueAsString(response)));
// 清理已合并的分片数据
clientVoiceParts.get(clientId).clear();
}
}
/**
* 从会话中获取客户端ID
*/
private String getClientId(WebSocketSession session) {
return (String) session.getAttributes().get("clientId");
} }
} }

View File

@@ -24,7 +24,7 @@ public class WebSocketConfig implements WebSocketConfigurer {
@Override @Override
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
// 注册WebSocket处理器设置路径和允许跨域 // 注册WebSocket处理器设置路径和允许跨域
registry.addHandler(voiceWebSocketHandler, "/voice-websocket123") registry.addHandler(voiceWebSocketHandler, "/voice-websocket")
.addInterceptors(voiceHandshakeInterceptor) .addInterceptors(voiceHandshakeInterceptor)
.setAllowedOrigins("*"); // 生产环境应指定具体域名而非* .setAllowedOrigins("*"); // 生产环境应指定具体域名而非*
} }