diff --git a/system/src/main/java/com/sdm/system/controller/WebSocketServer.java b/system/src/main/java/com/sdm/system/controller/WebSocketServer.java index c4dbf51c..22a1c8ba 100644 --- a/system/src/main/java/com/sdm/system/controller/WebSocketServer.java +++ b/system/src/main/java/com/sdm/system/controller/WebSocketServer.java @@ -11,30 +11,31 @@ import java.util.List; import java.util.Map; import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; @Slf4j @Component @ServerEndpoint("/sysWs") public class WebSocketServer { - // 在线会话池(安全) - public static final Map SESSION_POOL = new ConcurrentHashMap<>(); + // ==================== 修复后:一个用户对应多个Session ==================== + public static final Map> USER_SESSION_POOL = new ConcurrentHashMap<>(); // ==================== 连接建立 ==================== @OnOpen public void onOpen(Session session) { try { - // 获取 userId Long userId = ThreadLocalContext.getUserId(); - if(Objects.isNull(userId)){ + if (Objects.isNull(userId)) { throw new RuntimeException("userId不能是null"); } - // ========================================== - // userId 存入 Session - // ========================================== + session.getUserProperties().put("userid", userId); - SESSION_POOL.put(userId, session); - log.info("用户[{}]连接成功,当前在线:{}", userId, SESSION_POOL.size()); + + // 如果用户不存在,创建新List;存在则追加 + USER_SESSION_POOL.computeIfAbsent(userId, k -> new CopyOnWriteArrayList<>()).add(session); + + log.info("用户[{}]新设备连接成功,当前设备数:{}", userId, USER_SESSION_POOL.get(userId).size()); } catch (Exception e) { log.error("WebSocket 连接失败", e); } @@ -43,16 +44,21 @@ public class WebSocketServer { // ==================== 关闭连接 ==================== @OnClose public void onClose(Session session) { - // 从 Session 取 userId,绝对安全 - Long userId = (Long) session.getUserProperties().get("userid"); - if (userId != null) { - SESSION_POOL.remove(userId); - log.info("用户[{}]断开连接,当前在线:{}", userId, SESSION_POOL.size()); - } try { - if (session.isOpen()) session.close(); - } catch (Exception ignored) { - log.info("用户[{}]断开连接异常-->", userId, ignored); + Long userId = (Long) session.getUserProperties().get("userid"); + if (userId != null) { + List sessions = USER_SESSION_POOL.get(userId); + if (sessions != null) { + sessions.remove(session); + // 如果该用户没有设备了,移除用户key + if (sessions.isEmpty()) { + USER_SESSION_POOL.remove(userId); + } + } + log.info("用户[{}]设备断开,剩余设备数:{}", userId, sessions == null ? 0 : sessions.size()); + } + } catch (Exception e) { + log.error("关闭连接异常", e); } } @@ -60,6 +66,17 @@ public class WebSocketServer { @OnMessage public void onMessage(String message, Session session) { Long userId = (Long) session.getUserProperties().get("userid"); + + // ==================== 心跳处理:收到 ping 返回 pong,后期拓展交互数据的场景 ==================== + if ("ping".equals(message)) { + try { + session.getBasicRemote().sendText("pong"); + return; + } catch (Exception e) { + log.error("心跳pong发送失败", e); + } + } + log.info("收到用户[{}]消息:{}", userId, message); } @@ -70,53 +87,48 @@ public class WebSocketServer { log.error("用户[{}]异常:", userId, error); } - /** - * 推送消息给单一用户 - */ + // ==================== 推送消息给用户(所有设备都能收到) ==================== public static void sendToUser(Long userId, String message) { - Session session = SESSION_POOL.get(userId); - if (session != null && session.isOpen()) { - try { - session.getBasicRemote().sendText(message); - } catch (Exception e) { - log.error("用户[{}]消息发送失败-->", userId,e); - throw new RuntimeException("消息发送异常"+e.getMessage()); - } - }else { - log.warn("用户[{}]消息发送失败,session null or close", userId); - throw new RuntimeException("消息发送失败,session关闭"); + List sessions = USER_SESSION_POOL.get(userId); + if (sessions == null || sessions.isEmpty()) { + log.warn("用户[{}]不在线", userId); + return; } - } - /** - * 广播消息 -> 给所有在线用户发送 - */ - public static void sendToAllUser(String message) { - // 遍历所有在线用户 - for (Map.Entry entry : SESSION_POOL.entrySet()) { - Long userId = entry.getKey(); - Session session = entry.getValue(); + for (Session session : sessions) { try { - if (session != null && session.isOpen()) { + if (session.isOpen()) { session.getBasicRemote().sendText(message); - log.info("广播消息发送给用户[{}]成功", userId); } } catch (Exception e) { - // 发送失败就移除无效连接 - log.error("广播消息发送给用户[{}]失败-->", userId, e); + log.error("发送消息给用户[{}]失败", userId, e); } } } - public static List getAllWsUsers() { - List userIds = new ArrayList<>(); - for (Map.Entry entry : SESSION_POOL.entrySet()) { + // ==================== 广播 ==================== + public static void sendToAllUser(String message) { + for (Map.Entry> entry : USER_SESSION_POOL.entrySet()) { Long userId = entry.getKey(); - if(!Objects.isNull(userId)){ - userIds.add(userId); + List sessions = entry.getValue(); + for (Session session : sessions) { + try { + if (session.isOpen()) { + session.getBasicRemote().sendText(message); + } + } catch (Exception e) { + log.error("广播发送失败 用户[{}]", userId, e); + } } } - return userIds; + } + + /** + * 获取所有在线的用户ID(自动去重,一个用户只返回一次) + */ + public static List getAllWsUsers() { + // 直接返回 MAP 的所有 Key,天然去重(一个 userId 只存一次) + return new ArrayList<>(USER_SESSION_POOL.keySet()); } } \ No newline at end of file