From 75edd1de7d21b61e436c4cb4b58fa4c580fdcd30 Mon Sep 17 00:00:00 2001 From: gulongcheng <474084054@qq.com> Date: Mon, 27 Oct 2025 15:43:26 +0800 Subject: [PATCH] =?UTF-8?q?=E8=AE=AD=E7=BB=83=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data/pom.xml | 7 ++ .../com/sdm/data/config/WebSocketConfig.java | 22 +++- .../config/WebSocketHandshakeInterceptor.java | 56 +++++++++ .../controller/WebSocketTestController.java | 83 ++++++++++++++ .../sdm/data/model/req/HandleLoadDataReq.java | 4 +- .../sdm/data/service/WebSocketService.java | 106 ++++++++++++++---- .../data/service/impl/ModelServiceImpl.java | 19 ++-- .../main/resources/static/websocketTest.html | 94 ++++++++++++++++ .../system/dao/SysUserRoleRelationMapper.java | 16 +++ .../model/entity/SysUserRoleRelation.java | 49 ++++++++ .../service/ISysUserRoleRelationService.java | 16 +++ .../impl/SysUserRoleRelationServiceImpl.java | 20 ++++ .../mapper/SysUserRoleRelationMapper.xml | 5 + 13 files changed, 457 insertions(+), 40 deletions(-) create mode 100644 data/src/main/java/com/sdm/data/config/WebSocketHandshakeInterceptor.java create mode 100644 data/src/main/java/com/sdm/data/controller/WebSocketTestController.java create mode 100644 data/src/main/resources/static/websocketTest.html create mode 100644 system/src/main/java/com/sdm/system/dao/SysUserRoleRelationMapper.java create mode 100644 system/src/main/java/com/sdm/system/model/entity/SysUserRoleRelation.java create mode 100644 system/src/main/java/com/sdm/system/service/ISysUserRoleRelationService.java create mode 100644 system/src/main/java/com/sdm/system/service/impl/SysUserRoleRelationServiceImpl.java create mode 100644 system/src/main/resources/mapper/SysUserRoleRelationMapper.xml diff --git a/data/pom.xml b/data/pom.xml index 7ace68f6..406b5c72 100644 --- a/data/pom.xml +++ b/data/pom.xml @@ -94,6 +94,13 @@ org.springframework.boot spring-boot-starter-actuator + + + + org.springframework.boot + spring-boot-starter-websocket + + diff --git a/data/src/main/java/com/sdm/data/config/WebSocketConfig.java b/data/src/main/java/com/sdm/data/config/WebSocketConfig.java index dfc54e87..34c06080 100644 --- a/data/src/main/java/com/sdm/data/config/WebSocketConfig.java +++ b/data/src/main/java/com/sdm/data/config/WebSocketConfig.java @@ -1,4 +1,3 @@ -/* package com.sdm.data.config; import org.springframework.context.annotation.Configuration; @@ -13,14 +12,29 @@ import com.sdm.data.service.WebSocketService; public class WebSocketConfig implements WebSocketConfigurer { private final WebSocketService webSocketService; + private final WebSocketHandshakeInterceptor webSocketHandshakeInterceptor; - public WebSocketConfig(WebSocketService webSocketService) { + public WebSocketConfig(WebSocketService webSocketService, + WebSocketHandshakeInterceptor webSocketHandshakeInterceptor) { this.webSocketService = webSocketService; + this.webSocketHandshakeInterceptor = webSocketHandshakeInterceptor; } @Override public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { - registry.addHandler(webSocketService, "/ws/data-processing") + /*注册训练模型的WebSocket处理程序,并添加拦截器用于获取参数 + const ws = new WebSocket(`ws://localhost:7104/ws/data/modelTraining?userId=${userId}&modelId=${modelId}`); + */ + registry.addHandler(webSocketService, "/ws/data/modelTraining") + .addInterceptors(webSocketHandshakeInterceptor) .setAllowedOrigins("*"); + /** + // 后续注册新的WebSocket处理程序时,可以继续添加 + registry.addHandler(webSocketService, "/ws/notifications") + .setAllowedOrigins("*"); + + registry.addHandler(webSocketService, "/ws/chat") + .setAllowedOrigins("*"); + */ } -}*/ +} \ No newline at end of file diff --git a/data/src/main/java/com/sdm/data/config/WebSocketHandshakeInterceptor.java b/data/src/main/java/com/sdm/data/config/WebSocketHandshakeInterceptor.java new file mode 100644 index 00000000..4bbf208b --- /dev/null +++ b/data/src/main/java/com/sdm/data/config/WebSocketHandshakeInterceptor.java @@ -0,0 +1,56 @@ +package com.sdm.data.config; + +import lombok.extern.slf4j.Slf4j; +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServerHttpResponse; +import org.springframework.http.server.ServletServerHttpRequest; +import org.springframework.stereotype.Component; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.server.HandshakeInterceptor; + +import java.util.Map; + +@Slf4j +@Component +public class WebSocketHandshakeInterceptor implements HandshakeInterceptor { + + @Override + public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, + WebSocketHandler wsHandler, Map attributes) throws Exception { + if (request instanceof ServletServerHttpRequest) { + ServletServerHttpRequest servletRequest = (ServletServerHttpRequest) request; + + // 从请求参数中获取userId和modelId + String userIdStr = servletRequest.getServletRequest().getParameter("userId"); + String modelIdStr = servletRequest.getServletRequest().getParameter("modelId"); + + // 将参数转换为Integer并存储到attributes中,以便在WebSocketSession中使用 + if (userIdStr != null) { + try { + Integer userId = Integer.valueOf(userIdStr); + attributes.put("userId", userId); + } catch (NumberFormatException e) { + log.warn("无效的userId参数: {}", userIdStr); + } + } + if (modelIdStr != null) { + try { + Integer modelId = Integer.valueOf(modelIdStr); + attributes.put("modelId", modelId); + } catch (NumberFormatException e) { + log.warn("无效的modelId参数: {}", modelIdStr); + } + } + + log.info("WebSocket握手前,获取到参数: userId={}, modelId={}", userIdStr, modelIdStr); + } + return true; + } + + @Override + public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, + WebSocketHandler wsHandler, Exception exception) { + // 握手完成后可选的操作 + log.info("WebSocket握手完成"); + } +} \ No newline at end of file diff --git a/data/src/main/java/com/sdm/data/controller/WebSocketTestController.java b/data/src/main/java/com/sdm/data/controller/WebSocketTestController.java new file mode 100644 index 00000000..57ec76e4 --- /dev/null +++ b/data/src/main/java/com/sdm/data/controller/WebSocketTestController.java @@ -0,0 +1,83 @@ +package com.sdm.data.controller; + +import com.sdm.common.common.SdmResponse; +import com.sdm.data.service.WebSocketService; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.Parameter; +import io.swagger.v3.oas.annotations.tags.Tag; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.web.bind.annotation.*; + +/** + * WebSocket测试控制器 + * 提供测试接口用于服务端主动向客户端发送消息 + */ +@RestController +@RequestMapping("/websocket/test") +@Tag(name = "WebSocket测试", description = "WebSocket功能测试接口") +public class WebSocketTestController { + + @Autowired + private WebSocketService webSocketService; + + /** + * 向所有连接的客户端发送消息 + * + * @param message 消息内容 + * @return 发送结果 + */ + @PostMapping("/sendToAll") + @Operation(summary = "向所有客户端发送消息", description = "向所有连接的客户端发送指定消息") + public SdmResponse sendToAll(@Parameter(description = "消息内容") @RequestParam String message) { + try { + webSocketService.sendToAll(message); + return SdmResponse.success("消息发送成功"); + } catch (Exception e) { + return SdmResponse.failed("消息发送失败: " + e.getMessage()); + } + } + + /** + * 向指定会话发送消息 + * + * @param sessionId 会话ID + * @param message 消息内容 + * @return 发送结果 + */ + @PostMapping("/sendToSession") + @Operation(summary = "向指定客户端发送消息", description = "向指定会话ID的客户端发送消息") + public SdmResponse sendToSession( + @Parameter(description = "会话ID") @RequestParam String sessionId, + @Parameter(description = "消息内容") @RequestParam String message) { + try { + webSocketService.sendToSession(sessionId, message); + return SdmResponse.success("消息发送成功"); + } catch (Exception e) { + return SdmResponse.failed("消息发送失败: " + e.getMessage()); + } + } + + /** + * 发送数据处理完成通知 + * + * @param userId 用户ID + * @param modelId 模型ID + * @param success 是否成功 + * @param message 消息内容 + * @return 发送结果 + */ + @PostMapping("/sendDataProcessingNotification") + @Operation(summary = "发送数据处理通知", description = "向指定用户模型发送数据处理完成的通知消息") + public SdmResponse sendDataProcessingNotification( + @Parameter(description = "用户ID") @RequestParam Integer userId, + @Parameter(description = "模型ID") @RequestParam Integer modelId, + @Parameter(description = "是否成功") @RequestParam boolean success, + @Parameter(description = "消息内容") @RequestParam String message) { + try { + webSocketService.sendDataProcessingNotification(userId, modelId, success, message); + return SdmResponse.success("通知发送成功"); + } catch (Exception e) { + return SdmResponse.failed("通知发送失败: " + e.getMessage()); + } + } +} \ No newline at end of file diff --git a/data/src/main/java/com/sdm/data/model/req/HandleLoadDataReq.java b/data/src/main/java/com/sdm/data/model/req/HandleLoadDataReq.java index e6c019c1..bd000a74 100644 --- a/data/src/main/java/com/sdm/data/model/req/HandleLoadDataReq.java +++ b/data/src/main/java/com/sdm/data/model/req/HandleLoadDataReq.java @@ -13,6 +13,6 @@ public class HandleLoadDataReq { @Schema(description = "训练模型id") private Integer trainingModelId; - /* @Schema(description = "WebSocket会话ID") - private String sessionId;*/ + @Schema(description = "WebSocket会话userId") + private Integer userId; } \ No newline at end of file diff --git a/data/src/main/java/com/sdm/data/service/WebSocketService.java b/data/src/main/java/com/sdm/data/service/WebSocketService.java index a22d91dc..2a4abc96 100644 --- a/data/src/main/java/com/sdm/data/service/WebSocketService.java +++ b/data/src/main/java/com/sdm/data/service/WebSocketService.java @@ -1,49 +1,93 @@ -/* package com.sdm.data.service; import com.alibaba.fastjson2.JSONObject; import lombok.extern.slf4j.Slf4j; +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServerHttpResponse; import org.springframework.stereotype.Service; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.handler.TextWebSocketHandler; +import org.springframework.web.socket.server.HandshakeInterceptor; import java.io.IOException; +import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @Slf4j @Service public class WebSocketService extends TextWebSocketHandler { - // 存储所有活跃的WebSocket会话 private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + // 存储用户模型ID与会话ID的映射关系 (userId_modelId -> sessionId) + private final ConcurrentHashMap userModelSessionMap = new ConcurrentHashMap<>(); + @Override public void afterConnectionEstablished(WebSocketSession session) throws Exception { sessions.put(session.getId(), session); - log.info("WebSocket连接已建立,会话ID: {}", session.getId()); + + // 从会话属性中获取userId和modelId + Object userIdObj = session.getAttributes().get("userId"); + Object modelIdObj = session.getAttributes().get("modelId"); + + Integer userId = null; + Integer modelId = null; + + // 类型安全检查并转换 + if (userIdObj instanceof Integer) { + userId = (Integer) userIdObj; + } else if (userIdObj != null) { + log.warn("userId不是预期的Integer类型,实际类型为: {}", userIdObj.getClass().getName()); + } + + if (modelIdObj instanceof Integer) { + modelId = (Integer) modelIdObj; + } else if (modelIdObj != null) { + log.warn("modelId不是预期的Integer类型,实际类型为: {}", modelIdObj.getClass().getName()); + } + + // 只有当userId和modelId都存在时才建立映射关系 + if (userId != null && modelId != null) { + String key = userId + "_" + modelId; + userModelSessionMap.put(key, session.getId()); + log.info("用户模型 {} 已注册到会话 {}", key, session.getId()); + } + + log.info("WebSocket连接已建立,会话ID: {}, 用户ID: {}, 模型ID: {}", session.getId(), userId, modelId); // 发送连接成功的消息 JSONObject message = new JSONObject(); message.put("type", "connection"); message.put("status", "connected"); + message.put("sessionId", session.getId()); session.sendMessage(new TextMessage(message.toJSONString())); } @Override public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception { - sessions.remove(session.getId()); - log.info("WebSocket连接已关闭,会话ID: {},状态: {}", session.getId(), status); + // 清理所有映射关系 + String sessionId = session.getId(); + sessions.remove(sessionId); + + // 从用户模型映射中移除 + userModelSessionMap.values().removeIf(id -> id.equals(sessionId)); + + log.info("WebSocket连接已关闭,会话ID: {},状态: {}", sessionId, status); } - */ -/** + @Override + protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception { + String payload = message.getPayload(); + log.info("收到来自会话 {} 的消息: {}", session.getId(), payload); + } + + /** * 向所有连接的客户端发送消息 * @param message 消息内容 - *//* - + */ public void sendToAll(String message) { sessions.values().forEach(session -> { try { @@ -55,14 +99,12 @@ public class WebSocketService extends TextWebSocketHandler { } }); } - - */ -/** + + /** * 向指定会话发送消息 * @param sessionId 会话ID * @param message 消息内容 - *//* - + */ public void sendToSession(String sessionId, String message) { WebSocketSession session = sessions.get(sessionId); if (session != null && session.isOpen()) { @@ -74,20 +116,36 @@ public class WebSocketService extends TextWebSocketHandler { } } - */ -/** - * 发送数据处理完成的通知 - * @param sessionId 会话ID + /** + * 向指定用户的模型发送消息 + * @param userId 用户ID + * @param modelId 模型ID + * @param message 消息内容 + */ + public void sendToUserModel(Integer userId, Integer modelId, String message) { + String key = userId + "_" + modelId; + String sessionId = userModelSessionMap.get(key); + if (sessionId != null) { + sendToSession(sessionId, message); + } else { + log.warn("未找到用户 {} 的模型 {} 的WebSocket会话", userId, modelId); + } + } + + /** + * 发送训练模型数据处理完成的通知 + * @param userId 用户ID + * @param modelId 模型ID * @param success 是否成功 * @param message 消息内容 - *//* - - public void sendDataProcessingNotification(String sessionId, boolean success, String message) { + */ + public void sendDataProcessingNotification(Integer userId, Integer modelId, boolean success, String message) { JSONObject notification = new JSONObject(); - notification.put("type", "dataProcessing"); + notification.put("type", "modelDataProcessing"); + notification.put("modelId", modelId); notification.put("success", success); notification.put("message", message); - sendToSession(sessionId, notification.toJSONString()); + sendToUserModel(userId, modelId, notification.toJSONString()); } -}*/ +} \ No newline at end of file diff --git a/data/src/main/java/com/sdm/data/service/impl/ModelServiceImpl.java b/data/src/main/java/com/sdm/data/service/impl/ModelServiceImpl.java index 7d114514..adff0050 100644 --- a/data/src/main/java/com/sdm/data/service/impl/ModelServiceImpl.java +++ b/data/src/main/java/com/sdm/data/service/impl/ModelServiceImpl.java @@ -120,6 +120,9 @@ public class ModelServiceImpl implements IModelService { @Autowired ITrainingModelAlgorithmParamService trainingModelAlgorithmParamService; + @Autowired + WebSocketService webSocketService; + @Override @Transactional(rollbackFor = Exception.class) @@ -204,8 +207,7 @@ public class ModelServiceImpl implements IModelService { trainingModelService.lambdaUpdate().eq(TrainingModel::getId, trainingModelId).set(TrainingModel::getHandleStatus, "处理中").update(); // 异步调用Python脚本处理数据 - // String sessionId = handleLoadDataReq.getSessionId(); // 假设请求中包含sessionId - processDataAsync(DATA_HANDLER_PYTHON_SCRIPT_PATH, paramJsonPath, trainingModelId, null); + processDataAsync(DATA_HANDLER_PYTHON_SCRIPT_PATH, paramJsonPath, trainingModelId, handleLoadDataReq.getUserId(),handleLoadDataReq.getTrainingModelId()); return SdmResponse.success("数据处理中"); } catch (Exception e) { log.error("处理上传数据失败", e); @@ -266,9 +268,10 @@ public class ModelServiceImpl implements IModelService { * * @param pythonScriptPath Python脚本路径 * @param paramJsonPath 参数文件路径 - * @param sessionId WebSocket会话ID + * @param userId 用户ID + * @param modelId 模型ID */ - private void processDataAsync(String pythonScriptPath, String paramJsonPath, Integer trainingModelId, String sessionId) { + private void processDataAsync(String pythonScriptPath, String paramJsonPath, Integer trainingModelId, Integer userId, Integer modelId) { new Thread(() -> { try { // 调用Python脚本处理数据 @@ -347,16 +350,12 @@ public class ModelServiceImpl implements IModelService { log.info("训练模型ID: {}数据处理完成,处理结果文件ID:{},总耗时: {} ms", trainingModelId, fileId, duration); // 通过WebSocket通知前端处理完成 - /*if (sessionId != null) { - webSocketService.sendDataProcessingNotification(sessionId, true, "数据处理完成"); - }*/ + webSocketService.sendDataProcessingNotification(userId, modelId,true, "数据处理完成"); } catch (Exception e) { trainingModelService.lambdaUpdate().eq(TrainingModel::getId, trainingModelId).set(TrainingModel::getHandleStatus, "失败").update(); log.error("异步处理数据失败", e); // 通过WebSocket通知前端处理失败 - /*if (sessionId != null) { - webSocketService.sendDataProcessingNotification(sessionId, false, "数据处理失败: " + e.getMessage()); - }*/ + webSocketService.sendDataProcessingNotification(userId, modelId,false, "数据处理失败: " + e.getMessage()); } }).start(); } diff --git a/data/src/main/resources/static/websocketTest.html b/data/src/main/resources/static/websocketTest.html new file mode 100644 index 00000000..7676b1e4 --- /dev/null +++ b/data/src/main/resources/static/websocketTest.html @@ -0,0 +1,94 @@ + + + + + WebSocket 测试 + + +

WebSocket 测试页面

+
+ + + + + + + + + +
+
+ + + + \ No newline at end of file diff --git a/system/src/main/java/com/sdm/system/dao/SysUserRoleRelationMapper.java b/system/src/main/java/com/sdm/system/dao/SysUserRoleRelationMapper.java new file mode 100644 index 00000000..e6817d37 --- /dev/null +++ b/system/src/main/java/com/sdm/system/dao/SysUserRoleRelationMapper.java @@ -0,0 +1,16 @@ +package com.sdm.system.dao; + +import com.sdm.system.model.entity.SysUserRoleRelation; +import com.baomidou.mybatisplus.core.mapper.BaseMapper; + +/** + *

+ * 用户角色关联表 Mapper 接口 + *

+ * + * @author author + * @since 2025-10-23 + */ +public interface SysUserRoleRelationMapper extends BaseMapper { + +} diff --git a/system/src/main/java/com/sdm/system/model/entity/SysUserRoleRelation.java b/system/src/main/java/com/sdm/system/model/entity/SysUserRoleRelation.java new file mode 100644 index 00000000..abb4c85c --- /dev/null +++ b/system/src/main/java/com/sdm/system/model/entity/SysUserRoleRelation.java @@ -0,0 +1,49 @@ +package com.sdm.system.model.entity; + +import com.baomidou.mybatisplus.annotation.TableField; +import com.baomidou.mybatisplus.annotation.TableName; +import com.baomidou.mybatisplus.annotation.IdType; +import com.baomidou.mybatisplus.annotation.TableId; +import java.time.LocalDateTime; +import java.io.Serializable; +import io.swagger.annotations.ApiModel; +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.experimental.Accessors; + +/** + *

+ * 用户角色关联表 + *

+ * + * @author author + * @since 2025-10-23 + */ +@Data +@EqualsAndHashCode(callSuper = false) +@Accessors(chain = true) +@TableName("sys_user_role_relation") +@ApiModel(value="SysUserRoleRelation对象", description="用户角色关联表") +public class SysUserRoleRelation implements Serializable { + + private static final long serialVersionUID = 1L; + + @ApiModelProperty(value = "关联记录主键ID") + @TableId(value = "id", type = IdType.AUTO) + private Integer id; + + @ApiModelProperty(value = "用户ID(关联用户表主键)") + @TableField("user_id") + private Integer userId; + + @ApiModelProperty(value = "角色ID(关联角色表主键)") + @TableField("role_id") + private Integer roleId; + + @ApiModelProperty(value = "关联关系创建时间") + @TableField("add_time") + private LocalDateTime addTime; + + +} diff --git a/system/src/main/java/com/sdm/system/service/ISysUserRoleRelationService.java b/system/src/main/java/com/sdm/system/service/ISysUserRoleRelationService.java new file mode 100644 index 00000000..60321281 --- /dev/null +++ b/system/src/main/java/com/sdm/system/service/ISysUserRoleRelationService.java @@ -0,0 +1,16 @@ +package com.sdm.system.service; + +import com.sdm.system.model.entity.SysUserRoleRelation; +import com.baomidou.mybatisplus.extension.service.IService; + +/** + *

+ * 用户角色关联表 服务类 + *

+ * + * @author author + * @since 2025-10-23 + */ +public interface ISysUserRoleRelationService extends IService { + +} diff --git a/system/src/main/java/com/sdm/system/service/impl/SysUserRoleRelationServiceImpl.java b/system/src/main/java/com/sdm/system/service/impl/SysUserRoleRelationServiceImpl.java new file mode 100644 index 00000000..de36598c --- /dev/null +++ b/system/src/main/java/com/sdm/system/service/impl/SysUserRoleRelationServiceImpl.java @@ -0,0 +1,20 @@ +package com.sdm.system.service.impl; + +import com.sdm.system.model.entity.SysUserRoleRelation; +import com.sdm.system.dao.SysUserRoleRelationMapper; +import com.sdm.system.service.ISysUserRoleRelationService; +import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; +import org.springframework.stereotype.Service; + +/** + *

+ * 用户角色关联表 服务实现类 + *

+ * + * @author author + * @since 2025-10-23 + */ +@Service +public class SysUserRoleRelationServiceImpl extends ServiceImpl implements ISysUserRoleRelationService { + +} diff --git a/system/src/main/resources/mapper/SysUserRoleRelationMapper.xml b/system/src/main/resources/mapper/SysUserRoleRelationMapper.xml new file mode 100644 index 00000000..8ed68221 --- /dev/null +++ b/system/src/main/resources/mapper/SysUserRoleRelationMapper.xml @@ -0,0 +1,5 @@ + + + + +