websocket 生成完整的文件
This commit is contained in:
@@ -0,0 +1,58 @@
|
||||
package com.vetti.socket;
|
||||
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.web.socket.config.annotation.EnableWebSocket;
|
||||
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
|
||||
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
|
||||
import org.springframework.web.socket.server.HandshakeInterceptor;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import java.util.Map;
|
||||
|
||||
@Configuration
|
||||
@EnableWebSocket
|
||||
public class FileReceiverConfig implements WebSocketConfigurer {
|
||||
|
||||
@Bean
|
||||
public FileReceiverWebSocketHandler fileReceiverWebSocketHandler() {
|
||||
return new FileReceiverWebSocketHandler();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
|
||||
registry.addHandler(fileReceiverWebSocketHandler(), "/voice-websocket")
|
||||
.addInterceptors(handshakeInterceptor())
|
||||
.setAllowedOrigins("*"); // 生产环境需替换为具体域名
|
||||
}
|
||||
|
||||
@Bean
|
||||
public HandshakeInterceptor handshakeInterceptor() {
|
||||
return new HandshakeInterceptor() {
|
||||
@Override
|
||||
public boolean beforeHandshake(org.springframework.http.server.ServerHttpRequest request,
|
||||
org.springframework.http.server.ServerHttpResponse response,
|
||||
org.springframework.web.socket.WebSocketHandler wsHandler,
|
||||
Map<String, Object> attributes) throws Exception {
|
||||
if (request instanceof org.springframework.http.server.ServletServerHttpRequest) {
|
||||
HttpServletRequest servletRequest =
|
||||
((org.springframework.http.server.ServletServerHttpRequest) request).getServletRequest();
|
||||
String clientId = servletRequest.getParameter("clientId");
|
||||
if (clientId != null && !clientId.isEmpty()) {
|
||||
attributes.put("clientId", clientId);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterHandshake(org.springframework.http.server.ServerHttpRequest request,
|
||||
org.springframework.http.server.ServerHttpResponse response,
|
||||
org.springframework.web.socket.WebSocketHandler wsHandler,
|
||||
Exception exception) {
|
||||
// 握手后处理
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,141 @@
|
||||
package com.vetti.socket;
|
||||
|
||||
import com.vetti.socket.vo.FileMetadata;
|
||||
import com.vetti.socket.vo.FileTransferState;
|
||||
import org.springframework.web.socket.*;
|
||||
import org.springframework.web.socket.handler.TextWebSocketHandler;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import java.io.*;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.file.*;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
public class FileReceiverWebSocketHandler extends TextWebSocketHandler {
|
||||
// 存储客户端的文件传输状态:clientId -> FileTransferState
|
||||
private final Map<String, FileTransferState> transferStates = new ConcurrentHashMap<>();
|
||||
private final ObjectMapper objectMapper = new ObjectMapper();
|
||||
private static final String STORAGE_DIR = "received_files/";
|
||||
|
||||
// 初始化存储目录
|
||||
static {
|
||||
try {
|
||||
Files.createDirectories(Paths.get(STORAGE_DIR));
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("无法创建文件存储目录", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
|
||||
String clientId = (String) session.getAttributes().get("clientId");
|
||||
transferStates.put(clientId, new FileTransferState());
|
||||
System.out.println("客户端连接: " + clientId);
|
||||
}
|
||||
|
||||
// 处理文本消息(文件元数据)
|
||||
@Override
|
||||
protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
|
||||
String clientId = (String) session.getAttributes().get("clientId");
|
||||
FileMetadata metadata = objectMapper.readValue(message.getPayload(), FileMetadata.class);
|
||||
|
||||
// 初始化文件传输状态
|
||||
FileTransferState state = transferStates.get(clientId);
|
||||
state.setFileName(metadata.getFileName());
|
||||
state.setTotalSize(metadata.getTotalSize());
|
||||
state.setTotalParts(metadata.getTotalParts());
|
||||
state.setOutputStream(new FileOutputStream(STORAGE_DIR + metadata.getFileName()));
|
||||
|
||||
System.out.println("开始接收文件: " + metadata.getFileName() + " (" + metadata.getTotalParts() + "个分片)");
|
||||
|
||||
// 确认已收到元数据
|
||||
session.sendMessage(new TextMessage("{\"type\":\"metadata_ack\"}"));
|
||||
}
|
||||
|
||||
// 处理二进制消息(文件分片)
|
||||
@Override
|
||||
protected void handleBinaryMessage(WebSocketSession session, BinaryMessage message){
|
||||
try{
|
||||
System.out.println("开始-接收文件分片数据流");
|
||||
String clientId = (String) session.getAttributes().get("clientId");
|
||||
FileTransferState state = transferStates.get(clientId);
|
||||
|
||||
if (state == null || state.getOutputStream() == null) {
|
||||
session.sendMessage(new TextMessage("{\"type\":\"error\", \"message\":\"未收到文件元数据\"}"));
|
||||
return;
|
||||
}
|
||||
System.out.println("进行中-接收文件分片数据流");
|
||||
// 解析分片数据
|
||||
ByteBuffer payload = message.getPayload();
|
||||
// int partNumber = payload.getInt(); // 前4字节是分片编号
|
||||
byte[] data = new byte[payload.remaining()];
|
||||
payload.get(data);
|
||||
|
||||
// 写入文件
|
||||
state.getOutputStream().write(data);
|
||||
state.incrementReceivedParts();
|
||||
|
||||
// 发送进度更新(每5个分片或最后一个分片)
|
||||
if (state.getReceivedParts() % 5 == 0 || state.getReceivedParts() == state.getTotalParts()) {
|
||||
// 检查是否接收完成
|
||||
if (state.getReceivedParts() == state.getTotalParts()) {
|
||||
System.out.println("生成完整的文件-接收文件分片数据流");
|
||||
completeFileTransfer(session, state, clientId);
|
||||
//进行文件数据转换
|
||||
|
||||
//获取最终的文件结果
|
||||
|
||||
//把文件转成对应的文件流,返回给前端
|
||||
// session.sendMessage(new BinaryMessage());
|
||||
}
|
||||
}
|
||||
}catch (Exception e){
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
// 完成文件传输
|
||||
private void completeFileTransfer(WebSocketSession session, FileTransferState state, String clientId) throws IOException {
|
||||
// 关闭文件输出流
|
||||
state.getOutputStream().close();
|
||||
|
||||
// 验证文件大小
|
||||
File file = new File(STORAGE_DIR + state.getFileName());
|
||||
boolean fileValid = file.length() == state.getTotalSize();
|
||||
|
||||
// 发送完成消息
|
||||
String result = fileValid ?
|
||||
"{\"type\":\"complete\", \"message\":\"文件接收完成\", \"filePath\":\"" + file.getAbsolutePath() + "\"}" :
|
||||
"{\"type\":\"error\", \"message\":\"文件损坏,大小不匹配\"}";
|
||||
session.sendMessage(new TextMessage(result));
|
||||
|
||||
System.out.println("文件接收" + (fileValid ? "完成" : "失败") + ": " + state.getFileName());
|
||||
|
||||
// 清理状态
|
||||
transferStates.remove(clientId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
|
||||
String clientId = (String) session.getAttributes().get("clientId");
|
||||
FileTransferState state = transferStates.remove(clientId);
|
||||
|
||||
// 关闭可能存在的文件流
|
||||
if (state != null && state.getOutputStream() != null) {
|
||||
try {
|
||||
state.getOutputStream().close();
|
||||
// 删除未完成的文件
|
||||
Files.deleteIfExists(Paths.get(STORAGE_DIR + state.getFileName()));
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
System.out.println("客户端断开连接: " + clientId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
|
||||
System.err.println("传输错误: " + exception.getMessage());
|
||||
session.close(CloseStatus.SERVER_ERROR);
|
||||
}
|
||||
}
|
||||
@@ -81,15 +81,16 @@ public class VoiceWebSocketHandler extends TextWebSocketHandler {
|
||||
|
||||
// 发送进度更新(每5个分片或最后一个分片)
|
||||
if (state.getReceivedParts() % 5 == 0 || state.getReceivedParts() == state.getTotalParts()) {
|
||||
double progress = (double) state.getReceivedParts() / state.getTotalParts() * 100;
|
||||
session.sendMessage(new TextMessage(
|
||||
"{\"type\":\"progress\", \"progress\":" + progress + "}"
|
||||
));
|
||||
|
||||
// 检查是否接收完成
|
||||
if (state.getReceivedParts() == state.getTotalParts()) {
|
||||
log.info("生成完整的文件-接收文件分片数据流");
|
||||
completeFileTransfer(session, state, clientId);
|
||||
//进行文件数据转换
|
||||
|
||||
//获取最终的文件结果
|
||||
|
||||
//把文件转成对应的文件流,返回给前端
|
||||
// session.sendMessage(new BinaryMessage());
|
||||
}
|
||||
}
|
||||
}catch (Exception e){
|
||||
@@ -101,11 +102,9 @@ public class VoiceWebSocketHandler extends TextWebSocketHandler {
|
||||
private void completeFileTransfer(WebSocketSession session, FileTransferState state, String clientId) throws IOException {
|
||||
// 关闭文件输出流
|
||||
state.getOutputStream().close();
|
||||
|
||||
// 验证文件大小
|
||||
File file = new File(STORAGE_DIR + state.getFileName());
|
||||
boolean fileValid = file.length() == state.getTotalSize();
|
||||
|
||||
// 发送完成消息
|
||||
String result = fileValid ?
|
||||
"{\"type\":\"complete\", \"message\":\"文件接收完成\", \"filePath\":\"" + file.getAbsolutePath() + "\"}" :
|
||||
|
||||
@@ -24,7 +24,7 @@ public class WebSocketConfig implements WebSocketConfigurer {
|
||||
@Override
|
||||
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
|
||||
// 注册WebSocket处理器,设置路径和允许跨域
|
||||
registry.addHandler(voiceWebSocketHandler, "/voice-websocket")
|
||||
registry.addHandler(voiceWebSocketHandler, "/voice-websocket123")
|
||||
.addInterceptors(voiceHandshakeInterceptor)
|
||||
.setAllowedOrigins("*"); // 生产环境应指定具体域名而非*
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user