训练模型
This commit is contained in:
@@ -94,6 +94,13 @@
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-actuator</artifactId>
|
||||
</dependency>
|
||||
|
||||
<!-- WebSocket -->
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-websocket</artifactId>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
<dependencyManagement>
|
||||
<dependencies>
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
/*
|
||||
package com.sdm.data.config;
|
||||
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
@@ -13,14 +12,29 @@ import com.sdm.data.service.WebSocketService;
|
||||
public class WebSocketConfig implements WebSocketConfigurer {
|
||||
|
||||
private final WebSocketService webSocketService;
|
||||
private final WebSocketHandshakeInterceptor webSocketHandshakeInterceptor;
|
||||
|
||||
public WebSocketConfig(WebSocketService webSocketService) {
|
||||
public WebSocketConfig(WebSocketService webSocketService,
|
||||
WebSocketHandshakeInterceptor webSocketHandshakeInterceptor) {
|
||||
this.webSocketService = webSocketService;
|
||||
this.webSocketHandshakeInterceptor = webSocketHandshakeInterceptor;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
|
||||
registry.addHandler(webSocketService, "/ws/data-processing")
|
||||
/*注册训练模型的WebSocket处理程序,并添加拦截器用于获取参数
|
||||
const ws = new WebSocket(`ws://localhost:7104/ws/data/modelTraining?userId=${userId}&modelId=${modelId}`);
|
||||
*/
|
||||
registry.addHandler(webSocketService, "/ws/data/modelTraining")
|
||||
.addInterceptors(webSocketHandshakeInterceptor)
|
||||
.setAllowedOrigins("*");
|
||||
/**
|
||||
// 后续注册新的WebSocket处理程序时,可以继续添加
|
||||
registry.addHandler(webSocketService, "/ws/notifications")
|
||||
.setAllowedOrigins("*");
|
||||
|
||||
registry.addHandler(webSocketService, "/ws/chat")
|
||||
.setAllowedOrigins("*");
|
||||
*/
|
||||
}
|
||||
}*/
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
package com.sdm.data.config;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.http.server.ServerHttpRequest;
|
||||
import org.springframework.http.server.ServerHttpResponse;
|
||||
import org.springframework.http.server.ServletServerHttpRequest;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.web.socket.WebSocketHandler;
|
||||
import org.springframework.web.socket.server.HandshakeInterceptor;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class WebSocketHandshakeInterceptor implements HandshakeInterceptor {
|
||||
|
||||
@Override
|
||||
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
|
||||
WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
|
||||
if (request instanceof ServletServerHttpRequest) {
|
||||
ServletServerHttpRequest servletRequest = (ServletServerHttpRequest) request;
|
||||
|
||||
// 从请求参数中获取userId和modelId
|
||||
String userIdStr = servletRequest.getServletRequest().getParameter("userId");
|
||||
String modelIdStr = servletRequest.getServletRequest().getParameter("modelId");
|
||||
|
||||
// 将参数转换为Integer并存储到attributes中,以便在WebSocketSession中使用
|
||||
if (userIdStr != null) {
|
||||
try {
|
||||
Integer userId = Integer.valueOf(userIdStr);
|
||||
attributes.put("userId", userId);
|
||||
} catch (NumberFormatException e) {
|
||||
log.warn("无效的userId参数: {}", userIdStr);
|
||||
}
|
||||
}
|
||||
if (modelIdStr != null) {
|
||||
try {
|
||||
Integer modelId = Integer.valueOf(modelIdStr);
|
||||
attributes.put("modelId", modelId);
|
||||
} catch (NumberFormatException e) {
|
||||
log.warn("无效的modelId参数: {}", modelIdStr);
|
||||
}
|
||||
}
|
||||
|
||||
log.info("WebSocket握手前,获取到参数: userId={}, modelId={}", userIdStr, modelIdStr);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response,
|
||||
WebSocketHandler wsHandler, Exception exception) {
|
||||
// 握手完成后可选的操作
|
||||
log.info("WebSocket握手完成");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
package com.sdm.data.controller;
|
||||
|
||||
import com.sdm.common.common.SdmResponse;
|
||||
import com.sdm.data.service.WebSocketService;
|
||||
import io.swagger.v3.oas.annotations.Operation;
|
||||
import io.swagger.v3.oas.annotations.Parameter;
|
||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
|
||||
/**
|
||||
* WebSocket测试控制器
|
||||
* 提供测试接口用于服务端主动向客户端发送消息
|
||||
*/
|
||||
@RestController
|
||||
@RequestMapping("/websocket/test")
|
||||
@Tag(name = "WebSocket测试", description = "WebSocket功能测试接口")
|
||||
public class WebSocketTestController {
|
||||
|
||||
@Autowired
|
||||
private WebSocketService webSocketService;
|
||||
|
||||
/**
|
||||
* 向所有连接的客户端发送消息
|
||||
*
|
||||
* @param message 消息内容
|
||||
* @return 发送结果
|
||||
*/
|
||||
@PostMapping("/sendToAll")
|
||||
@Operation(summary = "向所有客户端发送消息", description = "向所有连接的客户端发送指定消息")
|
||||
public SdmResponse sendToAll(@Parameter(description = "消息内容") @RequestParam String message) {
|
||||
try {
|
||||
webSocketService.sendToAll(message);
|
||||
return SdmResponse.success("消息发送成功");
|
||||
} catch (Exception e) {
|
||||
return SdmResponse.failed("消息发送失败: " + e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 向指定会话发送消息
|
||||
*
|
||||
* @param sessionId 会话ID
|
||||
* @param message 消息内容
|
||||
* @return 发送结果
|
||||
*/
|
||||
@PostMapping("/sendToSession")
|
||||
@Operation(summary = "向指定客户端发送消息", description = "向指定会话ID的客户端发送消息")
|
||||
public SdmResponse sendToSession(
|
||||
@Parameter(description = "会话ID") @RequestParam String sessionId,
|
||||
@Parameter(description = "消息内容") @RequestParam String message) {
|
||||
try {
|
||||
webSocketService.sendToSession(sessionId, message);
|
||||
return SdmResponse.success("消息发送成功");
|
||||
} catch (Exception e) {
|
||||
return SdmResponse.failed("消息发送失败: " + e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 发送数据处理完成通知
|
||||
*
|
||||
* @param userId 用户ID
|
||||
* @param modelId 模型ID
|
||||
* @param success 是否成功
|
||||
* @param message 消息内容
|
||||
* @return 发送结果
|
||||
*/
|
||||
@PostMapping("/sendDataProcessingNotification")
|
||||
@Operation(summary = "发送数据处理通知", description = "向指定用户模型发送数据处理完成的通知消息")
|
||||
public SdmResponse sendDataProcessingNotification(
|
||||
@Parameter(description = "用户ID") @RequestParam Integer userId,
|
||||
@Parameter(description = "模型ID") @RequestParam Integer modelId,
|
||||
@Parameter(description = "是否成功") @RequestParam boolean success,
|
||||
@Parameter(description = "消息内容") @RequestParam String message) {
|
||||
try {
|
||||
webSocketService.sendDataProcessingNotification(userId, modelId, success, message);
|
||||
return SdmResponse.success("通知发送成功");
|
||||
} catch (Exception e) {
|
||||
return SdmResponse.failed("通知发送失败: " + e.getMessage());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -13,6 +13,6 @@ public class HandleLoadDataReq {
|
||||
@Schema(description = "训练模型id")
|
||||
private Integer trainingModelId;
|
||||
|
||||
/* @Schema(description = "WebSocket会话ID")
|
||||
private String sessionId;*/
|
||||
@Schema(description = "WebSocket会话userId")
|
||||
private Integer userId;
|
||||
}
|
||||
@@ -1,49 +1,93 @@
|
||||
/*
|
||||
package com.sdm.data.service;
|
||||
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.http.server.ServerHttpRequest;
|
||||
import org.springframework.http.server.ServerHttpResponse;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.web.socket.CloseStatus;
|
||||
import org.springframework.web.socket.TextMessage;
|
||||
import org.springframework.web.socket.WebSocketHandler;
|
||||
import org.springframework.web.socket.WebSocketSession;
|
||||
import org.springframework.web.socket.handler.TextWebSocketHandler;
|
||||
import org.springframework.web.socket.server.HandshakeInterceptor;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
public class WebSocketService extends TextWebSocketHandler {
|
||||
|
||||
// 存储所有活跃的WebSocket会话
|
||||
private final ConcurrentHashMap<String, WebSocketSession> sessions = new ConcurrentHashMap<>();
|
||||
|
||||
// 存储用户模型ID与会话ID的映射关系 (userId_modelId -> sessionId)
|
||||
private final ConcurrentHashMap<String, String> userModelSessionMap = new ConcurrentHashMap<>();
|
||||
|
||||
@Override
|
||||
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
|
||||
sessions.put(session.getId(), session);
|
||||
log.info("WebSocket连接已建立,会话ID: {}", session.getId());
|
||||
|
||||
// 从会话属性中获取userId和modelId
|
||||
Object userIdObj = session.getAttributes().get("userId");
|
||||
Object modelIdObj = session.getAttributes().get("modelId");
|
||||
|
||||
Integer userId = null;
|
||||
Integer modelId = null;
|
||||
|
||||
// 类型安全检查并转换
|
||||
if (userIdObj instanceof Integer) {
|
||||
userId = (Integer) userIdObj;
|
||||
} else if (userIdObj != null) {
|
||||
log.warn("userId不是预期的Integer类型,实际类型为: {}", userIdObj.getClass().getName());
|
||||
}
|
||||
|
||||
if (modelIdObj instanceof Integer) {
|
||||
modelId = (Integer) modelIdObj;
|
||||
} else if (modelIdObj != null) {
|
||||
log.warn("modelId不是预期的Integer类型,实际类型为: {}", modelIdObj.getClass().getName());
|
||||
}
|
||||
|
||||
// 只有当userId和modelId都存在时才建立映射关系
|
||||
if (userId != null && modelId != null) {
|
||||
String key = userId + "_" + modelId;
|
||||
userModelSessionMap.put(key, session.getId());
|
||||
log.info("用户模型 {} 已注册到会话 {}", key, session.getId());
|
||||
}
|
||||
|
||||
log.info("WebSocket连接已建立,会话ID: {}, 用户ID: {}, 模型ID: {}", session.getId(), userId, modelId);
|
||||
|
||||
// 发送连接成功的消息
|
||||
JSONObject message = new JSONObject();
|
||||
message.put("type", "connection");
|
||||
message.put("status", "connected");
|
||||
message.put("sessionId", session.getId());
|
||||
session.sendMessage(new TextMessage(message.toJSONString()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
|
||||
sessions.remove(session.getId());
|
||||
log.info("WebSocket连接已关闭,会话ID: {},状态: {}", session.getId(), status);
|
||||
// 清理所有映射关系
|
||||
String sessionId = session.getId();
|
||||
sessions.remove(sessionId);
|
||||
|
||||
// 从用户模型映射中移除
|
||||
userModelSessionMap.values().removeIf(id -> id.equals(sessionId));
|
||||
|
||||
log.info("WebSocket连接已关闭,会话ID: {},状态: {}", sessionId, status);
|
||||
}
|
||||
|
||||
*/
|
||||
/**
|
||||
@Override
|
||||
protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
|
||||
String payload = message.getPayload();
|
||||
log.info("收到来自会话 {} 的消息: {}", session.getId(), payload);
|
||||
}
|
||||
|
||||
/**
|
||||
* 向所有连接的客户端发送消息
|
||||
* @param message 消息内容
|
||||
*//*
|
||||
|
||||
*/
|
||||
public void sendToAll(String message) {
|
||||
sessions.values().forEach(session -> {
|
||||
try {
|
||||
@@ -55,14 +99,12 @@ public class WebSocketService extends TextWebSocketHandler {
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
*/
|
||||
/**
|
||||
|
||||
/**
|
||||
* 向指定会话发送消息
|
||||
* @param sessionId 会话ID
|
||||
* @param message 消息内容
|
||||
*//*
|
||||
|
||||
*/
|
||||
public void sendToSession(String sessionId, String message) {
|
||||
WebSocketSession session = sessions.get(sessionId);
|
||||
if (session != null && session.isOpen()) {
|
||||
@@ -74,20 +116,36 @@ public class WebSocketService extends TextWebSocketHandler {
|
||||
}
|
||||
}
|
||||
|
||||
*/
|
||||
/**
|
||||
* 发送数据处理完成的通知
|
||||
* @param sessionId 会话ID
|
||||
/**
|
||||
* 向指定用户的模型发送消息
|
||||
* @param userId 用户ID
|
||||
* @param modelId 模型ID
|
||||
* @param message 消息内容
|
||||
*/
|
||||
public void sendToUserModel(Integer userId, Integer modelId, String message) {
|
||||
String key = userId + "_" + modelId;
|
||||
String sessionId = userModelSessionMap.get(key);
|
||||
if (sessionId != null) {
|
||||
sendToSession(sessionId, message);
|
||||
} else {
|
||||
log.warn("未找到用户 {} 的模型 {} 的WebSocket会话", userId, modelId);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 发送训练模型数据处理完成的通知
|
||||
* @param userId 用户ID
|
||||
* @param modelId 模型ID
|
||||
* @param success 是否成功
|
||||
* @param message 消息内容
|
||||
*//*
|
||||
|
||||
public void sendDataProcessingNotification(String sessionId, boolean success, String message) {
|
||||
*/
|
||||
public void sendDataProcessingNotification(Integer userId, Integer modelId, boolean success, String message) {
|
||||
JSONObject notification = new JSONObject();
|
||||
notification.put("type", "dataProcessing");
|
||||
notification.put("type", "modelDataProcessing");
|
||||
notification.put("modelId", modelId);
|
||||
notification.put("success", success);
|
||||
notification.put("message", message);
|
||||
|
||||
sendToSession(sessionId, notification.toJSONString());
|
||||
sendToUserModel(userId, modelId, notification.toJSONString());
|
||||
}
|
||||
}*/
|
||||
}
|
||||
@@ -120,6 +120,9 @@ public class ModelServiceImpl implements IModelService {
|
||||
@Autowired
|
||||
ITrainingModelAlgorithmParamService trainingModelAlgorithmParamService;
|
||||
|
||||
@Autowired
|
||||
WebSocketService webSocketService;
|
||||
|
||||
|
||||
@Override
|
||||
@Transactional(rollbackFor = Exception.class)
|
||||
@@ -204,8 +207,7 @@ public class ModelServiceImpl implements IModelService {
|
||||
trainingModelService.lambdaUpdate().eq(TrainingModel::getId, trainingModelId).set(TrainingModel::getHandleStatus, "处理中").update();
|
||||
|
||||
// 异步调用Python脚本处理数据
|
||||
// String sessionId = handleLoadDataReq.getSessionId(); // 假设请求中包含sessionId
|
||||
processDataAsync(DATA_HANDLER_PYTHON_SCRIPT_PATH, paramJsonPath, trainingModelId, null);
|
||||
processDataAsync(DATA_HANDLER_PYTHON_SCRIPT_PATH, paramJsonPath, trainingModelId, handleLoadDataReq.getUserId(),handleLoadDataReq.getTrainingModelId());
|
||||
return SdmResponse.success("数据处理中");
|
||||
} catch (Exception e) {
|
||||
log.error("处理上传数据失败", e);
|
||||
@@ -266,9 +268,10 @@ public class ModelServiceImpl implements IModelService {
|
||||
*
|
||||
* @param pythonScriptPath Python脚本路径
|
||||
* @param paramJsonPath 参数文件路径
|
||||
* @param sessionId WebSocket会话ID
|
||||
* @param userId 用户ID
|
||||
* @param modelId 模型ID
|
||||
*/
|
||||
private void processDataAsync(String pythonScriptPath, String paramJsonPath, Integer trainingModelId, String sessionId) {
|
||||
private void processDataAsync(String pythonScriptPath, String paramJsonPath, Integer trainingModelId, Integer userId, Integer modelId) {
|
||||
new Thread(() -> {
|
||||
try {
|
||||
// 调用Python脚本处理数据
|
||||
@@ -347,16 +350,12 @@ public class ModelServiceImpl implements IModelService {
|
||||
log.info("训练模型ID: {}数据处理完成,处理结果文件ID:{},总耗时: {} ms", trainingModelId, fileId, duration);
|
||||
|
||||
// 通过WebSocket通知前端处理完成
|
||||
/*if (sessionId != null) {
|
||||
webSocketService.sendDataProcessingNotification(sessionId, true, "数据处理完成");
|
||||
}*/
|
||||
webSocketService.sendDataProcessingNotification(userId, modelId,true, "数据处理完成");
|
||||
} catch (Exception e) {
|
||||
trainingModelService.lambdaUpdate().eq(TrainingModel::getId, trainingModelId).set(TrainingModel::getHandleStatus, "失败").update();
|
||||
log.error("异步处理数据失败", e);
|
||||
// 通过WebSocket通知前端处理失败
|
||||
/*if (sessionId != null) {
|
||||
webSocketService.sendDataProcessingNotification(sessionId, false, "数据处理失败: " + e.getMessage());
|
||||
}*/
|
||||
webSocketService.sendDataProcessingNotification(userId, modelId,false, "数据处理失败: " + e.getMessage());
|
||||
}
|
||||
}).start();
|
||||
}
|
||||
|
||||
94
data/src/main/resources/static/websocketTest.html
Normal file
94
data/src/main/resources/static/websocketTest.html
Normal file
@@ -0,0 +1,94 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>WebSocket 测试</title>
|
||||
</head>
|
||||
<body>
|
||||
<h1>WebSocket 测试页面</h1>
|
||||
<div>
|
||||
<label for="userId">用户ID:</label>
|
||||
<input type="number" id="userId" value="1001" />
|
||||
<label for="modelId">模型ID:</label>
|
||||
<input type="number" id="modelId" value="2025" />
|
||||
<button onclick="connect()">连接</button>
|
||||
<button onclick="sendPing()">发送 Ping</button>
|
||||
<button onclick="sendSubscribe()">发送订阅</button>
|
||||
<button onclick="sendUnknown()">发送未知类型</button>
|
||||
<button onclick="disconnect()">断开连接</button>
|
||||
</div>
|
||||
<div id="messages" style="height: 300px; overflow: auto; border: 1px solid #ccc; margin-top: 10px;"></div>
|
||||
|
||||
<script>
|
||||
let ws;
|
||||
|
||||
function connect() {
|
||||
const userId = document.getElementById("userId").value;
|
||||
const modelId = document.getElementById("modelId").value;
|
||||
|
||||
// 修改WebSocket连接地址,通过网关连接到后端服务
|
||||
ws = new WebSocket(`ws://127.0.0.1:7100/simulation/data/ws/data/modelTraining?userId=${userId}&modelId=${modelId}`);
|
||||
|
||||
ws.onopen = function(event) {
|
||||
addMessage("连接已建立");
|
||||
};
|
||||
|
||||
ws.onmessage = function(event) {
|
||||
addMessage("收到消息: " + event.data);
|
||||
const message = JSON.parse(event.data);
|
||||
addMessage("消息类型: " + message.type);
|
||||
};
|
||||
|
||||
ws.onclose = function(event) {
|
||||
addMessage("连接已关闭");
|
||||
};
|
||||
|
||||
ws.onerror = function(error) {
|
||||
addMessage("发生错误: " + error);
|
||||
};
|
||||
}
|
||||
|
||||
function sendPing() {
|
||||
if (ws && ws.readyState === WebSocket.OPEN) {
|
||||
ws.send(JSON.stringify({
|
||||
"type": "ping"
|
||||
}));
|
||||
addMessage("已发送 ping 消息");
|
||||
}
|
||||
}
|
||||
|
||||
function sendSubscribe() {
|
||||
if (ws && ws.readyState === WebSocket.OPEN) {
|
||||
ws.send(JSON.stringify({
|
||||
"type": "subscribe",
|
||||
"topic": "dataProcessing"
|
||||
}));
|
||||
addMessage("已发送订阅消息");
|
||||
}
|
||||
}
|
||||
|
||||
function sendUnknown() {
|
||||
if (ws && ws.readyState === WebSocket.OPEN) {
|
||||
ws.send(JSON.stringify({
|
||||
"type": "unknown",
|
||||
"data": "test"
|
||||
}));
|
||||
addMessage("已发送未知类型消息");
|
||||
}
|
||||
}
|
||||
|
||||
function disconnect() {
|
||||
if (ws) {
|
||||
ws.close();
|
||||
}
|
||||
}
|
||||
|
||||
function addMessage(message) {
|
||||
const messagesDiv = document.getElementById("messages");
|
||||
const time = new Date().toLocaleTimeString();
|
||||
messagesDiv.innerHTML += `<p>[${time}] ${message}</p>`;
|
||||
messagesDiv.scrollTop = messagesDiv.scrollHeight;
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
Reference in New Issue
Block a user