训练模型

This commit is contained in:
2025-10-27 15:43:26 +08:00
parent c81033cd80
commit 75edd1de7d
13 changed files with 457 additions and 40 deletions

View File

@@ -94,6 +94,13 @@
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-actuator</artifactId>
</dependency>
<!-- WebSocket -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-websocket</artifactId>
</dependency>
</dependencies>
<dependencyManagement>
<dependencies>

View File

@@ -1,4 +1,3 @@
/*
package com.sdm.data.config;
import org.springframework.context.annotation.Configuration;
@@ -13,14 +12,29 @@ import com.sdm.data.service.WebSocketService;
public class WebSocketConfig implements WebSocketConfigurer {
private final WebSocketService webSocketService;
private final WebSocketHandshakeInterceptor webSocketHandshakeInterceptor;
public WebSocketConfig(WebSocketService webSocketService) {
public WebSocketConfig(WebSocketService webSocketService,
WebSocketHandshakeInterceptor webSocketHandshakeInterceptor) {
this.webSocketService = webSocketService;
this.webSocketHandshakeInterceptor = webSocketHandshakeInterceptor;
}
@Override
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
registry.addHandler(webSocketService, "/ws/data-processing")
/*注册训练模型的WebSocket处理程序并添加拦截器用于获取参数
const ws = new WebSocket(`ws://localhost:7104/ws/data/modelTraining?userId=${userId}&modelId=${modelId}`);
*/
registry.addHandler(webSocketService, "/ws/data/modelTraining")
.addInterceptors(webSocketHandshakeInterceptor)
.setAllowedOrigins("*");
/**
// 后续注册新的WebSocket处理程序时可以继续添加
registry.addHandler(webSocketService, "/ws/notifications")
.setAllowedOrigins("*");
registry.addHandler(webSocketService, "/ws/chat")
.setAllowedOrigins("*");
*/
}
}*/
}

View File

@@ -0,0 +1,56 @@
package com.sdm.data.config;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import java.util.Map;
@Slf4j
@Component
public class WebSocketHandshakeInterceptor implements HandshakeInterceptor {
@Override
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
if (request instanceof ServletServerHttpRequest) {
ServletServerHttpRequest servletRequest = (ServletServerHttpRequest) request;
// 从请求参数中获取userId和modelId
String userIdStr = servletRequest.getServletRequest().getParameter("userId");
String modelIdStr = servletRequest.getServletRequest().getParameter("modelId");
// 将参数转换为Integer并存储到attributes中以便在WebSocketSession中使用
if (userIdStr != null) {
try {
Integer userId = Integer.valueOf(userIdStr);
attributes.put("userId", userId);
} catch (NumberFormatException e) {
log.warn("无效的userId参数: {}", userIdStr);
}
}
if (modelIdStr != null) {
try {
Integer modelId = Integer.valueOf(modelIdStr);
attributes.put("modelId", modelId);
} catch (NumberFormatException e) {
log.warn("无效的modelId参数: {}", modelIdStr);
}
}
log.info("WebSocket握手前获取到参数: userId={}, modelId={}", userIdStr, modelIdStr);
}
return true;
}
@Override
public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler wsHandler, Exception exception) {
// 握手完成后可选的操作
log.info("WebSocket握手完成");
}
}

View File

@@ -0,0 +1,83 @@
package com.sdm.data.controller;
import com.sdm.common.common.SdmResponse;
import com.sdm.data.service.WebSocketService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.tags.Tag;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;
/**
* WebSocket测试控制器
* 提供测试接口用于服务端主动向客户端发送消息
*/
@RestController
@RequestMapping("/websocket/test")
@Tag(name = "WebSocket测试", description = "WebSocket功能测试接口")
public class WebSocketTestController {
@Autowired
private WebSocketService webSocketService;
/**
* 向所有连接的客户端发送消息
*
* @param message 消息内容
* @return 发送结果
*/
@PostMapping("/sendToAll")
@Operation(summary = "向所有客户端发送消息", description = "向所有连接的客户端发送指定消息")
public SdmResponse sendToAll(@Parameter(description = "消息内容") @RequestParam String message) {
try {
webSocketService.sendToAll(message);
return SdmResponse.success("消息发送成功");
} catch (Exception e) {
return SdmResponse.failed("消息发送失败: " + e.getMessage());
}
}
/**
* 向指定会话发送消息
*
* @param sessionId 会话ID
* @param message 消息内容
* @return 发送结果
*/
@PostMapping("/sendToSession")
@Operation(summary = "向指定客户端发送消息", description = "向指定会话ID的客户端发送消息")
public SdmResponse sendToSession(
@Parameter(description = "会话ID") @RequestParam String sessionId,
@Parameter(description = "消息内容") @RequestParam String message) {
try {
webSocketService.sendToSession(sessionId, message);
return SdmResponse.success("消息发送成功");
} catch (Exception e) {
return SdmResponse.failed("消息发送失败: " + e.getMessage());
}
}
/**
* 发送数据处理完成通知
*
* @param userId 用户ID
* @param modelId 模型ID
* @param success 是否成功
* @param message 消息内容
* @return 发送结果
*/
@PostMapping("/sendDataProcessingNotification")
@Operation(summary = "发送数据处理通知", description = "向指定用户模型发送数据处理完成的通知消息")
public SdmResponse sendDataProcessingNotification(
@Parameter(description = "用户ID") @RequestParam Integer userId,
@Parameter(description = "模型ID") @RequestParam Integer modelId,
@Parameter(description = "是否成功") @RequestParam boolean success,
@Parameter(description = "消息内容") @RequestParam String message) {
try {
webSocketService.sendDataProcessingNotification(userId, modelId, success, message);
return SdmResponse.success("通知发送成功");
} catch (Exception e) {
return SdmResponse.failed("通知发送失败: " + e.getMessage());
}
}
}

View File

@@ -13,6 +13,6 @@ public class HandleLoadDataReq {
@Schema(description = "训练模型id")
private Integer trainingModelId;
/* @Schema(description = "WebSocket会话ID")
private String sessionId;*/
@Schema(description = "WebSocket会话userId")
private Integer userId;
}

View File

@@ -1,49 +1,93 @@
/*
package com.sdm.data.service;
import com.alibaba.fastjson2.JSONObject;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.stereotype.Service;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
@Slf4j
@Service
public class WebSocketService extends TextWebSocketHandler {
// 存储所有活跃的WebSocket会话
private final ConcurrentHashMap<String, WebSocketSession> sessions = new ConcurrentHashMap<>();
// 存储用户模型ID与会话ID的映射关系 (userId_modelId -> sessionId)
private final ConcurrentHashMap<String, String> userModelSessionMap = new ConcurrentHashMap<>();
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
sessions.put(session.getId(), session);
log.info("WebSocket连接已建立会话ID: {}", session.getId());
// 从会话属性中获取userId和modelId
Object userIdObj = session.getAttributes().get("userId");
Object modelIdObj = session.getAttributes().get("modelId");
Integer userId = null;
Integer modelId = null;
// 类型安全检查并转换
if (userIdObj instanceof Integer) {
userId = (Integer) userIdObj;
} else if (userIdObj != null) {
log.warn("userId不是预期的Integer类型实际类型为: {}", userIdObj.getClass().getName());
}
if (modelIdObj instanceof Integer) {
modelId = (Integer) modelIdObj;
} else if (modelIdObj != null) {
log.warn("modelId不是预期的Integer类型实际类型为: {}", modelIdObj.getClass().getName());
}
// 只有当userId和modelId都存在时才建立映射关系
if (userId != null && modelId != null) {
String key = userId + "_" + modelId;
userModelSessionMap.put(key, session.getId());
log.info("用户模型 {} 已注册到会话 {}", key, session.getId());
}
log.info("WebSocket连接已建立会话ID: {}, 用户ID: {}, 模型ID: {}", session.getId(), userId, modelId);
// 发送连接成功的消息
JSONObject message = new JSONObject();
message.put("type", "connection");
message.put("status", "connected");
message.put("sessionId", session.getId());
session.sendMessage(new TextMessage(message.toJSONString()));
}
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
sessions.remove(session.getId());
log.info("WebSocket连接已关闭会话ID: {},状态: {}", session.getId(), status);
// 清理所有映射关系
String sessionId = session.getId();
sessions.remove(sessionId);
// 从用户模型映射中移除
userModelSessionMap.values().removeIf(id -> id.equals(sessionId));
log.info("WebSocket连接已关闭会话ID: {},状态: {}", sessionId, status);
}
*/
/**
@Override
protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
String payload = message.getPayload();
log.info("收到来自会话 {} 的消息: {}", session.getId(), payload);
}
/**
* 向所有连接的客户端发送消息
* @param message 消息内容
*//*
*/
public void sendToAll(String message) {
sessions.values().forEach(session -> {
try {
@@ -55,14 +99,12 @@ public class WebSocketService extends TextWebSocketHandler {
}
});
}
*/
/**
/**
* 向指定会话发送消息
* @param sessionId 会话ID
* @param message 消息内容
*//*
*/
public void sendToSession(String sessionId, String message) {
WebSocketSession session = sessions.get(sessionId);
if (session != null && session.isOpen()) {
@@ -74,20 +116,36 @@ public class WebSocketService extends TextWebSocketHandler {
}
}
*/
/**
* 发送数据处理完成的通知
* @param sessionId 会话ID
/**
* 向指定用户的模型发送消息
* @param userId 用户ID
* @param modelId 模型ID
* @param message 消息内容
*/
public void sendToUserModel(Integer userId, Integer modelId, String message) {
String key = userId + "_" + modelId;
String sessionId = userModelSessionMap.get(key);
if (sessionId != null) {
sendToSession(sessionId, message);
} else {
log.warn("未找到用户 {} 的模型 {} 的WebSocket会话", userId, modelId);
}
}
/**
* 发送训练模型数据处理完成的通知
* @param userId 用户ID
* @param modelId 模型ID
* @param success 是否成功
* @param message 消息内容
*//*
public void sendDataProcessingNotification(String sessionId, boolean success, String message) {
*/
public void sendDataProcessingNotification(Integer userId, Integer modelId, boolean success, String message) {
JSONObject notification = new JSONObject();
notification.put("type", "dataProcessing");
notification.put("type", "modelDataProcessing");
notification.put("modelId", modelId);
notification.put("success", success);
notification.put("message", message);
sendToSession(sessionId, notification.toJSONString());
sendToUserModel(userId, modelId, notification.toJSONString());
}
}*/
}

View File

@@ -120,6 +120,9 @@ public class ModelServiceImpl implements IModelService {
@Autowired
ITrainingModelAlgorithmParamService trainingModelAlgorithmParamService;
@Autowired
WebSocketService webSocketService;
@Override
@Transactional(rollbackFor = Exception.class)
@@ -204,8 +207,7 @@ public class ModelServiceImpl implements IModelService {
trainingModelService.lambdaUpdate().eq(TrainingModel::getId, trainingModelId).set(TrainingModel::getHandleStatus, "处理中").update();
// 异步调用Python脚本处理数据
// String sessionId = handleLoadDataReq.getSessionId(); // 假设请求中包含sessionId
processDataAsync(DATA_HANDLER_PYTHON_SCRIPT_PATH, paramJsonPath, trainingModelId, null);
processDataAsync(DATA_HANDLER_PYTHON_SCRIPT_PATH, paramJsonPath, trainingModelId, handleLoadDataReq.getUserId(),handleLoadDataReq.getTrainingModelId());
return SdmResponse.success("数据处理中");
} catch (Exception e) {
log.error("处理上传数据失败", e);
@@ -266,9 +268,10 @@ public class ModelServiceImpl implements IModelService {
*
* @param pythonScriptPath Python脚本路径
* @param paramJsonPath 参数文件路径
* @param sessionId WebSocket会话ID
* @param userId 用户ID
* @param modelId 模型ID
*/
private void processDataAsync(String pythonScriptPath, String paramJsonPath, Integer trainingModelId, String sessionId) {
private void processDataAsync(String pythonScriptPath, String paramJsonPath, Integer trainingModelId, Integer userId, Integer modelId) {
new Thread(() -> {
try {
// 调用Python脚本处理数据
@@ -347,16 +350,12 @@ public class ModelServiceImpl implements IModelService {
log.info("训练模型ID: {}数据处理完成,处理结果文件ID:{},总耗时: {} ms", trainingModelId, fileId, duration);
// 通过WebSocket通知前端处理完成
/*if (sessionId != null) {
webSocketService.sendDataProcessingNotification(sessionId, true, "数据处理完成");
}*/
webSocketService.sendDataProcessingNotification(userId, modelId,true, "数据处理完成");
} catch (Exception e) {
trainingModelService.lambdaUpdate().eq(TrainingModel::getId, trainingModelId).set(TrainingModel::getHandleStatus, "失败").update();
log.error("异步处理数据失败", e);
// 通过WebSocket通知前端处理失败
/*if (sessionId != null) {
webSocketService.sendDataProcessingNotification(sessionId, false, "数据处理失败: " + e.getMessage());
}*/
webSocketService.sendDataProcessingNotification(userId, modelId,false, "数据处理失败: " + e.getMessage());
}
}).start();
}

View File

@@ -0,0 +1,94 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>WebSocket 测试</title>
</head>
<body>
<h1>WebSocket 测试页面</h1>
<div>
<label for="userId">用户ID:</label>
<input type="number" id="userId" value="1001" />
<label for="modelId">模型ID:</label>
<input type="number" id="modelId" value="2025" />
<button onclick="connect()">连接</button>
<button onclick="sendPing()">发送 Ping</button>
<button onclick="sendSubscribe()">发送订阅</button>
<button onclick="sendUnknown()">发送未知类型</button>
<button onclick="disconnect()">断开连接</button>
</div>
<div id="messages" style="height: 300px; overflow: auto; border: 1px solid #ccc; margin-top: 10px;"></div>
<script>
let ws;
function connect() {
const userId = document.getElementById("userId").value;
const modelId = document.getElementById("modelId").value;
// 修改WebSocket连接地址通过网关连接到后端服务
ws = new WebSocket(`ws://127.0.0.1:7100/simulation/data/ws/data/modelTraining?userId=${userId}&modelId=${modelId}`);
ws.onopen = function(event) {
addMessage("连接已建立");
};
ws.onmessage = function(event) {
addMessage("收到消息: " + event.data);
const message = JSON.parse(event.data);
addMessage("消息类型: " + message.type);
};
ws.onclose = function(event) {
addMessage("连接已关闭");
};
ws.onerror = function(error) {
addMessage("发生错误: " + error);
};
}
function sendPing() {
if (ws && ws.readyState === WebSocket.OPEN) {
ws.send(JSON.stringify({
"type": "ping"
}));
addMessage("已发送 ping 消息");
}
}
function sendSubscribe() {
if (ws && ws.readyState === WebSocket.OPEN) {
ws.send(JSON.stringify({
"type": "subscribe",
"topic": "dataProcessing"
}));
addMessage("已发送订阅消息");
}
}
function sendUnknown() {
if (ws && ws.readyState === WebSocket.OPEN) {
ws.send(JSON.stringify({
"type": "unknown",
"data": "test"
}));
addMessage("已发送未知类型消息");
}
}
function disconnect() {
if (ws) {
ws.close();
}
}
function addMessage(message) {
const messagesDiv = document.getElementById("messages");
const time = new Date().toLocaleTimeString();
messagesDiv.innerHTML += `<p>[${time}] ${message}</p>`;
messagesDiv.scrollTop = messagesDiv.scrollHeight;
}
</script>
</body>
</html>