SpringBoot + WebSocket Keep-Alive for Weak Networks: Reconnecting After the App Goes to the Background and Resending Exactly the Missed Messages

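Anyone who has built IM chat, online games, or real-time push knows that keeping a WebSocket connection stable is hard:

  • The user rides the subway through a tunnel and the signal comes and goes
  • The app moves to the background and the OS may tear down the WebSocket connection
  • A corporate network requires authentication, and after a disconnect the client cannot get back on
  • The device goes through Dormant mode or switches networks, and its IP address changes

Message loss is an even bigger headache:

  • The user sends a message and the connection drops before the server receives it
  • The server pushes a message and the connection drops before the client receives it
  • After reconnecting, the client does not know which messages it missed and has to resync everything

This article walks through a keep-alive design for WebSocket on weak networks. The goal is to keep the app connected across network conditions and to recover the messages that were in flight when a connection dropped.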

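Why is WebSocket so fragile on weak networks?

Let's start with the common ways a WebSocket connection gets dropped:

1. The app goes to the background

Foreground: WebSocket stays connected ←→ normal traffic with the server
    ↓
Moved to the background: the OS may kill the WebSocket connection
    ↓
Back to the foreground: a new connection is needed, and messages pushed in the meantime were never delivered

2. Network switch

On WiFi: the WebSocket connection is established from IP1
    ↓
Switch to 4G: the IP becomes IP2 and the existing connection is dead
    ↓
Reconnect succeeds: but the server does not know which messages the client received

3. Weak network

Send a message → network jitter → the client times out and treats the send as failed
    ↓
Retry the send → the server had in fact already received the first copy
    ↓
Result: the server receives the same message twice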
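
Scenario 3 is usually handled by making receipt idempotent: the receiver remembers recently seen message IDs and drops repeats. The following is a minimal sketch under stated assumptions, not code from this article: the class name `SeenMessageFilter` and its capacity parameter are hypothetical, and it assumes every message carries a unique ID such as the `messageId` field of `ChatMessage` shown later.

```java
import java.util.LinkedHashMap;
import java.util.Map;

/** Remembers the most recent message IDs so that a retried message is processed once. */
final class SeenMessageFilter {

    private final Map<String, Boolean> seen;

    SeenMessageFilter(final int capacity) {
        // Insertion-ordered map that evicts its eldest entry once it grows past capacity.
        this.seen = new LinkedHashMap<String, Boolean>(16, 0.75f, false) {
            @Override
            protected boolean removeEldestEntry(Map.Entry<String, Boolean> eldest) {
                return size() > capacity;
            }
        };
    }

    /** Returns true the first time an ID is offered and false for a duplicate. */
    synchronized boolean firstTime(String messageId) {
        return seen.put(messageId, Boolean.TRUE) == null;
    }
}
```

A handler would call `firstTime(message.getMessageId())` before persisting and skip the message when it returns false. Because the memory is bounded, a duplicate that arrives after its ID has been evicted is not caught, so the capacity should cover the longest realistic retry window.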

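4. Server restart

The user is connected to server Node1 and exchanging messages
    ↓
Node1 restarts and the load balancer routes the user to Node2
    ↓
Node2 has no session state for this user and does not know which messages were already delivered or read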

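Overall architecture

The keep-alive design is built from these core components:

  1. WebSocketSessionManager: session manager, tracks every WebSocket connection
  2. HeartbeatScheduler: heartbeat scheduler, periodically checks whether connections are still alive
  3. MessageQueueManager: message queue, holds messages that have not been acknowledged yet
  4. MessageResender: message resender, resends unacknowledged messages after a reconnect
  5. SequenceManager: sequence number manager, used to detect lost and duplicated messages
  6. ReconnectionManager: reconnection manager, controls the reconnect strategy

Let's look at how to implement this in SpringBoot: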

1. WebSocket configuration and session management

First, the WebSocket configuration and the session manager:

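import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;

@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {

    // This is the application's own WebSocketHandler class (section 7). It shares its simple
    // name with Spring's org.springframework.web.socket.WebSocketHandler interface.
    @Autowired
    private WebSocketHandler webSocketHandler;

    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        // The last path segment carries the user ID; the handler parses it from the URI.
        registry.addHandler(webSocketHandler, "/ws/{userId}")
                // "*" accepts every origin; restrict it to known origins in production.
                .setAllowedOrigins("*");
    }
}
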
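import java.io.IOException;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;

import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator;

@Component
@Slf4j
public class WebSocketSessionManager {

    // Limits for the send decorator below; both values are illustrative.
    private static final int SEND_TIME_LIMIT_MS = 5000;
    private static final int SEND_BUFFER_LIMIT_BYTES = 512 * 1024;

    // userId -> current session (this design allows one connection per user)
    private final Map<String, WebSocketSession> sessions = new ConcurrentHashMap<>();

    // userId -> timestamp (ms) of the last frame received from that user
    private final Map<String, AtomicLong> sessionHeartbeat = new ConcurrentHashMap<>();

    public void registerSession(String userId, WebSocketSession session) {
        // A raw WebSocketSession does not support concurrent sends, and both handler
        // threads and the resend scheduler write to it, so wrap it in a decorator.
        sessions.put(userId, new ConcurrentWebSocketSessionDecorator(
            session, SEND_TIME_LIMIT_MS, SEND_BUFFER_LIMIT_BYTES));
        sessionHeartbeat.put(userId, new AtomicLong(System.currentTimeMillis()));
        log.info("WebSocket session registered: userId={}, sessionId={}", userId, session.getId());
    }

    // Caveat: if the user has already reconnected, a late close callback for the old
    // session also removes the new one. Compare session IDs first if that matters.
    public void removeSession(String userId) {
        sessions.remove(userId);
        sessionHeartbeat.remove(userId);
        log.info("WebSocket session removed: userId={}", userId);
    }

    public WebSocketSession getSession(String userId) {
        return sessions.get(userId);
    }

    public boolean isConnected(String userId) {
        WebSocketSession session = sessions.get(userId);
        return session != null && session.isOpen();
    }

    public void updateHeartbeat(String userId) {
        AtomicLong heartbeat = sessionHeartbeat.get(userId);
        if (heartbeat != null) {
            heartbeat.set(System.currentTimeMillis());
        }
    }

    // Returns a snapshot, so callers can iterate while sessions come and go.
    public Set<String> getAllConnectedUsers() {
        return new HashSet<>(sessions.keySet());
    }

    public void sendToUser(String userId, TextMessage message) {
        WebSocketSession session = sessions.get(userId);
        if (session != null && session.isOpen()) {
            try {
                session.sendMessage(message);
                log.debug("Message sent to user: userId={}, message={}", userId, message.getPayload());
            } catch (IOException e) {
                log.error("Failed to send message: userId={}", userId, e);
            }
        } else {
            log.warn("User is not connected: userId={}", userId);
        }
    }

    public void broadcast(TextMessage message) {
        for (WebSocketSession session : sessions.values()) {
            if (session.isOpen()) {
                try {
                    session.sendMessage(message);
                } catch (IOException e) {
                    // One failed session must not stop the broadcast to the others.
                    log.error("Broadcast failed: sessionId={}", session.getId(), e);
                }
            }
        }
    }
}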

2. Sequence number manager

The core component for detecting lost and duplicated messages:

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;

import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

@Component
@Slf4j
public class SequenceManager {

    // userId -> last sequence number assigned to a message from that user
    private final Map<String, AtomicLong> sendSequence = new ConcurrentHashMap<>();

    // userId -> next sequence number expected from that user (starts at 1)
    private final Map<String, AtomicLong> receiveSequence = new ConcurrentHashMap<>();

    // userId -> highest sequence number that user has acknowledged
    private final Map<String, Long> lastAckedSequence = new ConcurrentHashMap<>();

    public Long generateSendSequence(String userId) {
        AtomicLong sequence = sendSequence.computeIfAbsent(userId, k -> new AtomicLong(0));
        return sequence.incrementAndGet();
    }

    public Long getNextExpectedReceiveSequence(String userId) {
        AtomicLong sequence = receiveSequence.computeIfAbsent(userId, k -> new AtomicLong(1L));
        return sequence.get();
    }

    // Returns true only for the exact next sequence number; duplicates and gaps return false.
    public boolean validateReceiveSequence(String userId, Long sequenceNo) {
        if (sequenceNo == null) {
            return false;
        }

        AtomicLong expected = receiveSequence.computeIfAbsent(userId, k -> new AtomicLong(1L));
        long current = expected.get();

        if (sequenceNo == current) {
            // compareAndSet keeps check-and-advance atomic if two threads race on one user.
            return expected.compareAndSet(current, current + 1);
        }

        if (sequenceNo < current) {
            log.debug("Stale message received: userId={}, expected={}, actual={}", userId, current, sequenceNo);
            return false;
        }

        log.warn("Future message received, sequence gap: userId={}, expected={}, actual={}", userId, current, sequenceNo);
        return false;
    }

    public void markAcked(String userId, Long sequenceNo) {
        if (sequenceNo != null) {
            // Keep the maximum so a late, out-of-order ACK cannot move the mark backwards.
            lastAckedSequence.merge(userId, sequenceNo, Long::max);
            log.debug("Message acknowledged: userId={}, sequenceNo={}", userId, sequenceNo);
        }
    }

    public Long getLastAckedSequence(String userId) {
        return lastAckedSequence.get(userId);
    }

    public void resetSendSequence(String userId) {
        sendSequence.remove(userId);
        log.info("Send sequence reset: userId={}", userId);
    }

    public void resetReceiveSequence(String userId) {
        receiveSequence.remove(userId);
        log.info("Receive sequence reset: userId={}", userId);
    }
}
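
`validateReceiveSequence` simply rejects a message whose sequence number is ahead of the expected one. A common alternative is to hold such messages until the gap is filled and then release them in order. The sketch below illustrates that idea only; `ReorderBuffer` is a hypothetical name, payloads are plain strings for brevity, and the buffer is unbounded, which a real service would need to cap.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.TreeMap;

/** Holds out-of-order messages and releases them in sequence once the gaps are filled. */
final class ReorderBuffer {

    private final TreeMap<Long, String> held = new TreeMap<>();
    private long nextExpected = 1L; // same starting point as SequenceManager

    /** Accepts one message and returns every message that is now deliverable, in order. */
    synchronized List<String> accept(long sequenceNo, String payload) {
        List<String> ready = new ArrayList<>();
        if (sequenceNo < nextExpected) {
            return ready; // duplicate of a message that was already delivered
        }
        held.putIfAbsent(sequenceNo, payload);
        // Drain the head of the buffer for as long as it continues the sequence.
        while (!held.isEmpty() && held.firstKey() == nextExpected) {
            ready.add(held.pollFirstEntry().getValue());
            nextExpected++;
        }
        return ready;
    }

    synchronized long nextExpected() {
        return nextExpected;
    }
}
```

With this buffer, receiving 2 and then 1 delivers both messages in order after the second call, and a repeated 1 delivers nothing.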

3. Message queue management

Holds messages that have not been acknowledged yet:

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicInteger;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

@Component
@Slf4j
public class MessageQueueManager {

    // userId -> messages sent to that user and not yet acknowledged, oldest first
    private final Map<String, ConcurrentLinkedQueue<PendingMessage>> pendingMessages = new ConcurrentHashMap<>();

    // userId -> queue length, kept separately because ConcurrentLinkedQueue.size() is O(n)
    private final Map<String, AtomicInteger> pendingCount = new ConcurrentHashMap<>();

    private final int maxQueueSize = 1000;

    public void addPendingMessage(String userId, ChatMessage message) {
        ConcurrentLinkedQueue<PendingMessage> queue = pendingMessages.computeIfAbsent(
            userId, k -> new ConcurrentLinkedQueue<>());
        AtomicInteger count = pendingCount.computeIfAbsent(userId, k -> new AtomicInteger(0));

        if (count.get() >= maxQueueSize) {
            log.warn("Pending queue is full, dropping the oldest message: userId={}, size={}", userId, count.get());
            // Keep the counter in step with the queue when an entry is evicted.
            if (queue.poll() != null) {
                count.decrementAndGet();
            }
        }

        queue.add(new PendingMessage(message, System.currentTimeMillis()));
        count.incrementAndGet();

        log.debug("Pending message added: userId={}, sequenceNo={}, queueSize={}",
            userId, message.getSequenceNo(), count.get());
    }

    public List<ChatMessage> getUnackedMessages(String userId) {
        ConcurrentLinkedQueue<PendingMessage> queue = pendingMessages.get(userId);
        if (queue == null || queue.isEmpty()) {
            return Collections.emptyList();
        }

        List<ChatMessage> unacked = new ArrayList<>();
        for (PendingMessage pm : queue) {
            unacked.add(pm.getMessage());
        }
        return unacked;
    }

    public void removePendingMessage(String userId, Long sequenceNo) {
        ConcurrentLinkedQueue<PendingMessage> queue = pendingMessages.get(userId);
        if (queue == null) {
            return;
        }

        Iterator<PendingMessage> iterator = queue.iterator();
        while (iterator.hasNext()) {
            PendingMessage pm = iterator.next();
            if (pm.getMessage().getSequenceNo().equals(sequenceNo)) {
                iterator.remove();
                AtomicInteger count = pendingCount.get(userId);
                if (count != null && count.get() > 0) {
                    count.decrementAndGet();
                }
                log.debug("Acknowledged message removed: userId={}, sequenceNo={}", userId, sequenceNo);
                break;
            }
        }
    }

    public void clearPendingMessages(String userId) {
        pendingMessages.remove(userId);
        pendingCount.remove(userId);
        log.info("Pending queue cleared: userId={}", userId);
    }

    public int getPendingCount(String userId) {
        AtomicInteger count = pendingCount.get(userId);
        return count == null ? 0 : count.get();
    }

    @Data
    @AllArgsConstructor
    public static class PendingMessage {
        private ChatMessage message;
        private long addTime; // when the message entered the queue (ms)
    }
}
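
`removePendingMessage` scans the queue to remove one sequence number per ACK. If the client instead acknowledges cumulatively ("I have everything up to N"), an ordered map lets the server drop the whole acknowledged prefix in one call. This is a sketch of that variation rather than the article's design: `PendingWindow` is a hypothetical name and payloads are plain strings for brevity.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentNavigableMap;
import java.util.concurrent.ConcurrentSkipListMap;

/** Pending messages for one user, ordered by sequence number. */
final class PendingWindow {

    private final ConcurrentSkipListMap<Long, String> pending = new ConcurrentSkipListMap<>();

    void add(long sequenceNo, String payload) {
        pending.put(sequenceNo, payload);
    }

    /** Cumulative ACK: removes every message with sequenceNo <= ackedSeq, returns the count. */
    int ackUpTo(long ackedSeq) {
        // headMap is a live view, so clearing it removes the entries from the backing map.
        ConcurrentNavigableMap<Long, String> acked = pending.headMap(ackedSeq, true);
        int removed = acked.size(); // only approximate if other threads add concurrently
        acked.clear();
        return removed;
    }

    /** Messages still waiting for an ACK, in sequence order. */
    List<String> unacked() {
        return new ArrayList<>(pending.values());
    }
}
```

The trade-off is that a cumulative ACK cannot express "I have 5 but not 4", so it fits best when the client also delivers strictly in order.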

4. Message resender

Resends unacknowledged messages after a reconnect:

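import java.io.IOException;
import java.util.List;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

import com.alibaba.fastjson.JSON; // assumes Fastjson is on the classpath
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;

@Component
@Slf4j
public class MessageResender {

    @Autowired
    private WebSocketSessionManager sessionManager;

    @Autowired
    private MessageQueueManager messageQueue;

    @Autowired
    private SequenceManager sequenceManager;

    @Autowired
    private ChatMessageMapper messageMapper;

    private final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(2);

    // Resends everything still in the user's pending queue. Messages stay queued until
    // the client ACKs them, so the client must tolerate receiving a message twice.
    public void resendUnackedMessages(String userId) {
        WebSocketSession session = sessionManager.getSession(userId);
        if (session == null || !session.isOpen()) {
            log.warn("User is not connected, cannot resend: userId={}", userId);
            return;
        }

        List<ChatMessage> unackedMessages = messageQueue.getUnackedMessages(userId);
        if (unackedMessages.isEmpty()) {
            log.info("Nothing to resend: userId={}", userId);
            return;
        }

        log.info("Resending unacknowledged messages: userId={}, count={}", userId, unackedMessages.size());

        for (ChatMessage message : unackedMessages) {
            try {
                TextMessage textMessage = new TextMessage(JSON.toJSONString(message));
                session.sendMessage(textMessage);
                log.debug("Message resent: userId={}, sequenceNo={}", userId, message.getSequenceNo());
            } catch (IOException e) {
                log.error("Resend failed: userId={}, sequenceNo={}", userId, message.getSequenceNo(), e);
                // Stop at the first failure to preserve ordering; the rest go out next time.
                break;
            }
        }
    }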

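    public void syncMissedMessages(String userId, Long lastReceivedSeq) {
        log.info("Syncing missed messages: userId={}, lastReceivedSeq={}", userId, lastReceivedSeq);

        // Highest sequence number recorded for this user. This relies on the receive
        // counter in SequenceManager being advanced as messages arrive.
        long latestSeq = sequenceManager.getNextExpectedReceiveSequence(userId) - 1;

        if (lastReceivedSeq > latestSeq) {
            // The client claims to have seen more than the server has recorded.
            log.warn("Client sequence is ahead of the server: userId={}, clientSeq={}, latestSeq={}",
                userId, lastReceivedSeq, latestSeq);
            return;
        }

        // selectBySequenceRange treats startSeq as exclusive (sequence_no > startSeq), so
        // pass lastReceivedSeq itself; passing lastReceivedSeq + 1 would skip one message.
        List<ChatMessage> missedMessages = messageMapper.selectBySequenceRange(
            userId, lastReceivedSeq, latestSeq);

        if (missedMessages.isEmpty()) {
            log.info("No missed messages: userId={}", userId);
            return;
        }

        WebSocketSession session = sessionManager.getSession(userId);
        if (session == null || !session.isOpen()) {
            log.warn("User is not connected, cannot sync: userId={}", userId);
            return;
        }

        log.info("Sending missed messages: userId={}, missedCount={}", userId, missedMessages.size());

        for (ChatMessage message : missedMessages) {
            try {
                TextMessage textMessage = new TextMessage(JSON.toJSONString(message));
                session.sendMessage(textMessage);
                log.debug("Missed message sent: userId={}, sequenceNo={}", userId, message.getSequenceNo());
            } catch (IOException e) {
                log.error("Sync failed: userId={}, sequenceNo={}", userId, message.getSequenceNo(), e);
                break;
            }
        }
    }

    // Each call adds another repeating task that is never cancelled. Call it at most once
    // per user, or keep the ScheduledFuture that scheduleWithFixedDelay returns and cancel
    // it when the user disconnects.
    public void startScheduledResend(String userId) {
        scheduler.scheduleWithFixedDelay(() -> {
            try {
                resendUnackedMessages(userId);
            } catch (Exception e) {
                // An uncaught exception would silently stop all future runs of this task.
                log.error("Scheduled resend failed: userId={}", userId, e);
            }
        }, 5, 5, TimeUnit.SECONDS);
    }
}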
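
The scheduled resend above replays the whole pending queue every 5 seconds, including messages that were sent a moment ago and whose ACK is probably still in flight. One way to reduce duplicate traffic is to resend only entries whose last send is older than an ACK timeout. The sketch below shows just that selection step; `ResendPolicy`, `Entry`, and the timeout parameter are hypothetical names introduced for illustration.

```java
import java.util.ArrayList;
import java.util.List;

/** Picks which pending messages are old enough to resend. */
final class ResendPolicy {

    /** One pending message: its sequence number and when it was last sent. */
    static final class Entry {
        final long sequenceNo;
        final long lastSentAtMs;

        Entry(long sequenceNo, long lastSentAtMs) {
            this.sequenceNo = sequenceNo;
            this.lastSentAtMs = lastSentAtMs;
        }
    }

    /** Returns the sequence numbers whose last send is at least ackTimeoutMs old. */
    static List<Long> dueForResend(List<Entry> pending, long nowMs, long ackTimeoutMs) {
        List<Long> due = new ArrayList<>();
        for (Entry e : pending) {
            if (nowMs - e.lastSentAtMs >= ackTimeoutMs) {
                due.add(e.sequenceNo);
            }
        }
        return due;
    }
}
```

In the article's classes, the `addTime` field of `PendingMessage` could play the role of `lastSentAtMs`, provided it is refreshed on every resend.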

5. Reconnection manager

Controls the reconnect strategy:

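import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

// Reconnects are initiated by the client. On the server this class only tracks state;
// the client section later in this article reuses the same API to drive its retries.
@Component
@Slf4j
public class ReconnectionManager {

    private final Map<String, ReconnectContext> reconnectContexts = new ConcurrentHashMap<>();

    // userId -> attempts in the current backoff cycle (drives the delay calculation)
    private final Map<String, AtomicInteger> reconnectAttempts = new ConcurrentHashMap<>();

    private final int maxAttempts = 10;

    private final long baseDelayMs = 1000;

    private final long maxDelayMs = 30000;

    public boolean shouldReconnect(String userId) {
        ReconnectContext context = reconnectContexts.get(userId);

        if (context == null) {
            return true;
        }

        if (context.getState() == ReconnectState.CONNECTED) {
            reconnectAttempts.remove(userId);
            return false;
        }

        if (context.getAttempts() >= maxAttempts) {
            log.warn("Reconnect attempt limit reached: userId={}, attempts={}", userId, context.getAttempts());
            return false;
        }

        long elapsed = System.currentTimeMillis() - context.getStartTime();
        if (elapsed > maxDelayMs * 2) {
            // Restart the backoff from the base delay. The attempt total kept in the
            // context is not reset here, so the maxAttempts limit still applies.
            log.info("Reconnect cycle timed out, resetting backoff: userId={}, elapsed={}ms", userId, elapsed);
            reconnectAttempts.remove(userId);
            return true;
        }

        return true;
    }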

    public long getNextReconnectDelay(String userId) {
        int attempts = reconnectAttempts.computeIfAbsent(userId, k -> new AtomicInteger(0)).get();

        long delay = baseDelayMs * (long) Math.pow(2, attempts);
        delay = Math.min(delay, maxDelayMs);

        delay += (long) (Math.random() * baseDelayMs);

        return delay;
    }

    public void recordReconnectAttempt(String userId) {
        int attempts = reconnectAttempts.computeIfAbsent(userId, k -> new AtomicInteger(0)).incrementAndGet();

        ReconnectContext context = reconnectContexts.computeIfAbsent(userId, k -> new ReconnectContext());
        context.setAttempts(attempts);

        // The delay logged here is the one the next attempt would use, and it includes
        // fresh random jitter, so it will not match the delay already handed out exactly.
        log.info("Reconnect attempt recorded: userId={}, attempt={}, nextDelay={}ms",
            userId, attempts, getNextReconnectDelay(userId));
    }

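    public void onConnected(String userId) {
        ReconnectContext context = reconnectContexts.get(userId);
        if (context != null) {
            context.setState(ReconnectState.CONNECTED);
            // Clear the total as well, otherwise attempts from earlier outages keep adding
            // up and the user eventually hits maxAttempts despite reconnecting each time.
            context.setAttempts(0);
        }
        reconnectAttempts.remove(userId);
        log.info("Connected: userId={}", userId);
    }

    public void onDisconnected(String userId) {
        ReconnectContext context = reconnectContexts.computeIfAbsent(userId, k -> new ReconnectContext());
        long now = System.currentTimeMillis();
        context.setState(ReconnectState.DISCONNECTED);
        context.setDisconnectTime(now);
        // A new outage starts a new reconnect cycle for the elapsed-time check.
        context.setStartTime(now);
        log.info("Disconnected: userId={}", userId);
    }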

    public ReconnectContext getContext(String userId) {
        return reconnectContexts.get(userId);
    }

    public void clearContext(String userId) {
        reconnectContexts.remove(userId);
        reconnectAttempts.remove(userId);
    }

    public enum ReconnectState {
        DISCONNECTED,
        CONNECTING,
        CONNECTED
    }

    @Data
    public static class ReconnectContext {
        private ReconnectState state = ReconnectState.DISCONNECTED;
        private int attempts = 0;
        private long startTime = System.currentTimeMillis();
        private long disconnectTime;
    }
}
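
`getNextReconnectDelay` computes base × 2^attempts, caps the result at the maximum delay, and then adds random jitter of up to one base delay. The same arithmetic is sketched below as a standalone class so that the numbers are easy to check. `Backoff` is a hypothetical name, and passing in the `Random` is a choice made here for testability, not something the article does.

```java
import java.util.Random;

/** Exponential backoff with a cap and additive jitter. */
final class Backoff {

    /** Delay before retry number `attempt` (0-based), without jitter. */
    static long baseDelay(int attempt, long baseMs, long maxMs) {
        // Clamp the exponent so the shift stays in range for base delays of a few seconds.
        int exp = Math.min(Math.max(attempt, 0), 30);
        return Math.min(baseMs << exp, maxMs);
    }

    /** Same delay plus random jitter in [0, baseMs), which spreads out reconnect storms. */
    static long withJitter(int attempt, long baseMs, long maxMs, Random random) {
        return baseDelay(attempt, baseMs, maxMs) + (long) (random.nextDouble() * baseMs);
    }
}
```

With the article's settings (base 1000 ms, cap 30000 ms) the un-jittered delays are 1 s, 2 s, 4 s, 8 s, 16 s, and then 30 s for every later attempt.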

6. Heartbeat scheduler

Periodically checks whether connections are still alive:

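import java.io.IOException;
import java.util.Set;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

// jakarta.annotation on Spring Boot 3; use javax.annotation on Spring Boot 2.
import jakarta.annotation.PostConstruct;
import jakarta.annotation.PreDestroy;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.WebSocketSession;

@Component
@Slf4j
public class HeartbeatScheduler {

    @Autowired
    private WebSocketSessionManager sessionManager;

    @Autowired
    private WebSocketHandler webSocketHandler;

    private final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(2);

    private final long heartbeatIntervalMs = 15000;

    // Three missed heartbeats in a row count as a dead connection.
    private final long heartbeatTimeoutMs = 45000;

    @PostConstruct
    public void init() {
        scheduler.scheduleAtFixedRate(this::checkHeartbeat, heartbeatIntervalMs, heartbeatIntervalMs, TimeUnit.MILLISECONDS);
        log.info("Heartbeat check started: interval={}ms, timeout={}ms", heartbeatIntervalMs, heartbeatTimeoutMs);
    }

    private void checkHeartbeat() {
        Set<String> allUsers = sessionManager.getAllConnectedUsers();
        long now = System.currentTimeMillis();

        for (String userId : allUsers) {
            try {
                Long lastHeartbeat = getLastHeartbeat(userId);
                if (lastHeartbeat == null) {
                    continue;
                }

                long elapsed = now - lastHeartbeat;
                if (elapsed > heartbeatTimeoutMs) {
                    log.warn("Heartbeat timed out, closing session: userId={}, elapsed={}ms", userId, elapsed);
                    handleHeartbeatTimeout(userId);
                }
            } catch (Exception e) {
                // Catch per user: an exception escaping the task would cancel the schedule.
                log.error("Heartbeat check failed: userId={}", userId, e);
            }
        }
    }

    private void handleHeartbeatTimeout(String userId) {
        if (sessionManager.isConnected(userId)) {
            WebSocketSession session = sessionManager.getSession(userId);
            if (session != null) {
                try {
                    // Closing the half-dead session lets the client notice and reconnect.
                    session.close();
                } catch (IOException e) {
                    log.error("Failed to close timed-out session: userId={}", userId, e);
                }
            }
        }

        webSocketHandler.notifyDisconnection(userId);
    }

    // Placeholder: this always returns null, so no timeout ever fires. It must be wired to
    // the timestamp that WebSocketSessionManager.updateHeartbeat records, which requires a
    // getter that the session manager does not expose in this article.
    private Long getLastHeartbeat(String userId) {
        return null;
    }

    @PreDestroy
    public void shutdown() {
        scheduler.shutdown();
        log.info("Heartbeat check stopped");
    }
}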
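
The timeout check needs two things: a timestamp updated whenever a frame arrives from a user, and a way to read it back. The sketch below shows both in one small class, with the clock passed in so the logic can be tested without waiting. `HeartbeatTracker` is a hypothetical name, and nothing here is tied to Spring.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/** Records the last time each user was heard from and reports who has gone silent. */
final class HeartbeatTracker {

    private final Map<String, Long> lastSeen = new ConcurrentHashMap<>();
    private final long timeoutMs;

    HeartbeatTracker(long timeoutMs) {
        this.timeoutMs = timeoutMs;
    }

    /** Call on every frame received from the user, not only on HEARTBEAT frames. */
    void touch(String userId, long nowMs) {
        lastSeen.put(userId, nowMs);
    }

    /** Call when the session closes so the user is no longer checked. */
    void forget(String userId) {
        lastSeen.remove(userId);
    }

    /** Users whose last frame is older than the timeout. */
    List<String> timedOut(long nowMs) {
        List<String> result = new ArrayList<>();
        for (Map.Entry<String, Long> e : lastSeen.entrySet()) {
            if (nowMs - e.getValue() > timeoutMs) {
                result.add(e.getKey());
            }
        }
        return result;
    }
}
```

A scheduler would call `timedOut(System.currentTimeMillis())` on every tick and close the session of each user it returns.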

7. Core WebSocket handler

Ties all the components together:

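import java.io.IOException;

import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;

@Component
@Slf4j
public class WebSocketHandler extends TextWebSocketHandler {

    @Autowired
    private WebSocketSessionManager sessionManager;

    @Autowired
    private SequenceManager sequenceManager;

    @Autowired
    private MessageQueueManager messageQueue;

    @Autowired
    private MessageResender messageResender;

    @Autowired
    private ReconnectionManager reconnectionManager;

    @Autowired
    private ChatMessageMapper messageMapper;

    private final ObjectMapper objectMapper = new ObjectMapper();

    @Override
    public void afterConnectionEstablished(WebSocketSession session) {
        String userId = extractUserId(session);
        if (userId == null) {
            // No user ID in the URL: reject the connection.
            try {
                session.close();
            } catch (IOException e) {
                log.error("Failed to close invalid session", e);
            }
            return;
        }

        log.info("WebSocket connection established: userId={}, sessionId={}", userId, session.getId());
        sessionManager.registerSession(userId, session);
        reconnectionManager.onConnected(userId);

        // A recorded ACK means this user was connected before, so treat this as a reconnect.
        // This state lives in memory and is lost when the server restarts.
        Long lastReceivedSeq = sequenceManager.getLastAckedSequence(userId);
        if (lastReceivedSeq != null) {
            log.info("User reconnected, sending sync request: userId={}, lastSeq={}", userId, lastReceivedSeq);
            sendSyncRequest(session, lastReceivedSeq);
        }

        // Replay everything that was sent to this user but never acknowledged.
        messageResender.resendUnackedMessages(userId);
    }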

    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage message) {
        String userId = extractUserId(session);
        if (userId == null) {
            return;
        }

        try {
            WebSocketRequest request = objectMapper.readValue(message.getPayload(), WebSocketRequest.class);
            // Any inbound frame proves the connection is alive, not only HEARTBEAT frames.
            sessionManager.updateHeartbeat(userId);

            if (request.getType() == null) {
                // switch on a null String would throw NullPointerException.
                log.warn("Message without a type: userId={}", userId);
                return;
            }

            switch (request.getType()) {
                case "CHAT_MESSAGE":
                    handleChatMessage(userId, request);
                    break;
                case "HEARTBEAT":
                    handleHeartbeat(userId, request);
                    break;
                case "ACK":
                    handleAck(userId, request);
                    break;
                case "SYNC_REQUEST":
                    handleSyncRequest(userId, request);
                    break;
                default:
                    log.warn("Unknown message type: userId={}, type={}", userId, request.getType());
            }
        } catch (Exception e) {
            log.error("Failed to handle message: userId={}", userId, e);
        }
    }

    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) {
        String userId = extractUserId(session);
        if (userId == null) {
            return;
        }

        log.info("WebSocket connection closed: userId={}, sessionId={}, status={}", userId, session.getId(), status);
        sessionManager.removeSession(userId);
        reconnectionManager.onDisconnected(userId);
    }

    private void handleChatMessage(String userId, WebSocketRequest request) {
        ChatMessage chatMessage = request.getMessage();
        if (chatMessage == null || chatMessage.getReceiverIds() == null) {
            return;
        }

        // Note: this sequence is per sender. A receiver with several senders would need a
        // per-receiver or per-conversation sequence for its ACKs to be unambiguous.
        Long sequenceNo = sequenceManager.generateSendSequence(userId);
        chatMessage.setSequenceNo(sequenceNo);
        chatMessage.setSenderId(userId);
        chatMessage.setSendTime(System.currentTimeMillis());

        // Persist before delivery so the message survives a crash or restart.
        messageMapper.insert(chatMessage);

        for (String receiverId : chatMessage.getReceiverIds()) {
            // Queue under the receiver: handleAck removes the entry when that receiver ACKs,
            // and resendUnackedMessages replays it when that receiver reconnects.
            messageQueue.addPendingMessage(receiverId, chatMessage);

            WebSocketSession receiverSession = sessionManager.getSession(receiverId);
            if (receiverSession != null && receiverSession.isOpen()) {
                try {
                    // Deliver the chat message itself, serialized as JSON.
                    receiverSession.sendMessage(new TextMessage(objectMapper.writeValueAsString(chatMessage)));
                    log.debug("Message sent to receiver: receiverId={}, sequenceNo={}", receiverId, sequenceNo);
                } catch (IOException e) {
                    log.error("Failed to send message to receiver: receiverId={}", receiverId, e);
                }
            }
        }

        // Tell the sender that the server has accepted and stored the message.
        try {
            WebSocketResponse ackResponse = WebSocketResponse.success(sequenceNo);
            WebSocketSession session = sessionManager.getSession(userId);
            if (session != null && session.isOpen()) {
                session.sendMessage(new TextMessage(objectMapper.writeValueAsString(ackResponse)));
            }
        } catch (IOException e) {
            log.error("Failed to send ACK: userId={}", userId, e);
        }
    }

    private void handleHeartbeat(String userId, WebSocketRequest request) {
        try {
            WebSocketResponse response = WebSocketResponse.heartbeat();
            WebSocketSession session = sessionManager.getSession(userId);
            if (session != null && session.isOpen()) {
                session.sendMessage(new TextMessage(objectMapper.writeValueAsString(response)));
            }
        } catch (IOException e) {
            log.error("Failed to send heartbeat response: userId={}", userId, e);
        }
    }

    private void handleAck(String userId, WebSocketRequest request) {
        Long sequenceNo = request.getSequenceNo();
        if (sequenceNo == null) {
            return;
        }

        sequenceManager.markAcked(userId, sequenceNo);
        messageQueue.removePendingMessage(userId, sequenceNo);

        log.debug("ACK received: userId={}, sequenceNo={}", userId, sequenceNo);
    }

    private void handleSyncRequest(String userId, WebSocketRequest request) {
        Long lastSeq = request.getLastSequenceNo();
        if (lastSeq == null) {
            // The client has no local history: sync from the beginning.
            lastSeq = 0L;
        }

        messageResender.syncMissedMessages(userId, lastSeq);
    }

    private void sendSyncRequest(WebSocketSession session, Long lastSeq) {
        try {
            WebSocketRequest syncRequest = new WebSocketRequest();
            syncRequest.setType("SYNC_REQUEST");
            syncRequest.setLastSequenceNo(lastSeq);
            session.sendMessage(new TextMessage(objectMapper.writeValueAsString(syncRequest)));
        } catch (IOException e) {
            log.error("Failed to send sync request: sessionId={}", session.getId(), e);
        }
    }

    private String extractUserId(WebSocketSession session) {
        // getUri() can be null, and a path such as "/ws/" carries no user ID segment.
        if (session.getUri() == null || session.getUri().getPath() == null) {
            return null;
        }
        // "/ws/<userId>" splits into ["", "ws", "<userId>"].
        String[] parts = session.getUri().getPath().split("/");
        if (parts.length < 3 || parts[parts.length - 1].isEmpty()) {
            return null;
        }
        return parts[parts.length - 1];
    }

    // Hook called by HeartbeatScheduler after a timeout. Left empty in this article; this
    // is where presence updates or offline push would go.
    public void notifyDisconnection(String userId) {
    }
}
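
Taking "the last path segment" as the user ID accepts more URLs than intended, for example `/ws/a/b` yields `b`. A stricter parse that only accepts exactly `/ws/<userId>` can be sketched as follows. `UserIdPath` is a hypothetical helper, and it assumes the application is deployed without a context path in front of `/ws/`.

```java
import java.net.URI;
import java.util.Optional;

/** Extracts the user ID from a WebSocket URL of the form /ws/<userId>. */
final class UserIdPath {

    private static final String PREFIX = "/ws/";

    static Optional<String> extract(URI uri) {
        if (uri == null || uri.getPath() == null) {
            return Optional.empty();
        }
        String path = uri.getPath();
        if (!path.startsWith(PREFIX)) {
            return Optional.empty();
        }
        String id = path.substring(PREFIX.length());
        // Reject an empty ID and anything with further path segments.
        if (id.isEmpty() || id.contains("/")) {
            return Optional.empty();
        }
        return Optional.of(id);
    }
}
```

Returning `Optional` instead of null makes the "no user ID" case explicit at every call site.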

8. Message entity and request/response objects

@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class ChatMessage implements Serializable {

    private static final long serialVersionUID = 1L;

    private Long sequenceNo;

    private String messageId;

    private String senderId;

    private List<String> receiverIds;

    private String content;

    private String messageType;

    private Long sendTime;

    private Integer status;
}
@Data
public class WebSocketRequest {

    private String type;

    private ChatMessage message;

    private Long sequenceNo;

    private Long lastSequenceNo;
}
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class WebSocketResponse {

    private String type;

    private boolean success;

    private Long sequenceNo;

    private String errorMessage;

    private Long serverTime;

    public static WebSocketResponse success(Long sequenceNo) {
        return WebSocketResponse.builder()
            .type("ACK")
            .success(true)
            .sequenceNo(sequenceNo)
            .serverTime(System.currentTimeMillis())
            .build();
    }

    public static WebSocketResponse heartbeat() {
        return WebSocketResponse.builder()
            .type("HEARTBEAT")
            .success(true)
            .serverTime(System.currentTimeMillis())
            .build();
    }

    public static WebSocketResponse error(String message) {
        return WebSocketResponse.builder()
            .type("ERROR")
            .success(false)
            .errorMessage(message)
            .serverTime(System.currentTimeMillis())
            .build();
    }
}

9. Data access layer

@Mapper
public interface ChatMessageMapper {

    // receiverIds is a List<String> stored in a VARCHAR column, so it needs a MyBatis
    // TypeHandler (not shown here). The generated-key option is omitted because
    // ChatMessage has no "id" property to receive the key.
    @Insert("INSERT INTO chat_message (message_id, sender_id, receiver_ids, content, message_type, send_time, status, sequence_no) " +
            "VALUES (#{messageId}, #{senderId}, #{receiverIds}, #{content}, #{messageType}, #{sendTime}, #{status}, #{sequenceNo})")
    void insert(ChatMessage message);

    // Caution: the LIKE match is a substring match, so user "12" also matches "123".
    @Select("SELECT * FROM chat_message WHERE sender_id = #{userId} OR receiver_ids LIKE CONCAT('%', #{userId}, '%') ORDER BY sequence_no DESC LIMIT 100")
    List<ChatMessage> selectRecentMessages(@Param("userId") String userId);

    // Range is (startSeq, endSeq]: the lower bound is exclusive, the upper bound inclusive.
    // The filter is on sender_id, so this returns messages sent by the user.
    @Select("SELECT * FROM chat_message WHERE sender_id = #{userId} AND sequence_no > #{startSeq} AND sequence_no <= #{endSeq} ORDER BY sequence_no")
    List<ChatMessage> selectBySequenceRange(@Param("userId") String userId, @Param("startSeq") Long startSeq, @Param("endSeq") Long endSeq);
}

10. Configuration and database schema

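# The classes above hardcode these values as constants. For this file to take effect,
# bind the properties with @ConfigurationProperties or @Value.
websocket:
  heartbeat:
    interval-seconds: 15
    timeout-seconds: 45
  reconnect:
    max-attempts: 10
    base-delay-ms: 1000
    max-delay-ms: 30000
  message:
    max-pending-count: 1000
    resend-interval-seconds: 5

CREATE TABLE chat_message (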
    id BIGINT PRIMARY KEY AUTO_INCREMENT,
    message_id VARCHAR(36) NOT NULL,
    sender_id VARCHAR(36) NOT NULL,
    receiver_ids VARCHAR(500),
    content TEXT,
    message_type VARCHAR(20),
    send_time BIGINT,
    status INT DEFAULT 0,
    sequence_no BIGINT,
    create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    INDEX idx_sender (sender_id),
    INDEX idx_sequence (sender_id, sequence_no),
    INDEX idx_send_time (send_time)
);

Client reconnect strategy

The mobile app needs to implement the following reconnect strategy:

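import java.net.URI;
import java.util.HashMap;
import java.util.Map;

import android.os.Handler;
import android.os.Looper;
import com.alibaba.fastjson.JSON; // assumes Fastjson, as on the server side
import org.java_websocket.client.WebSocketClient;
import org.java_websocket.handshake.ServerHandshake;

// getCurrentUserId, buildWebSocketUrl, handleServerMessage and getLocalLastSequenceNo are
// app-specific helpers that are not shown in this article.
public class WebSocketClientManager {

    private static final long HEARTBEAT_INTERVAL_MS = 15000;

    private final Handler handler = new Handler(Looper.getMainLooper());

    // Named task so the heartbeat loop can be cancelled without touching other callbacks.
    private final Runnable heartbeatTask = this::sendHeartbeat;

    // Client-side copy of the reconnect bookkeeping shown in section 5.
    private final ReconnectionManager reconnectionManager = new ReconnectionManager();

    private WebSocketClient webSocketClient;

    private String userId;

    // True while a reconnect is already queued, so onError plus onClose schedule only one.
    private volatile boolean reconnectPending;

    public void connect() {
        userId = getCurrentUserId();
        // URI.create throws an unchecked exception, unlike the checked one from new URI(...).
        URI uri = URI.create(buildWebSocketUrl(userId));

        webSocketClient = new WebSocketClient(uri) {
            @Override
            public void onOpen(ServerHandshake handshakedata) {
                reconnectionManager.onConnected(userId);
                sendHeartbeat();
                requestMessageSync();
            }

            @Override
            public void onMessage(String message) {
                handleServerMessage(message);
            }

            @Override
            public void onClose(int code, String reason, boolean remote) {
                handler.removeCallbacks(heartbeatTask);
                reconnectionManager.onDisconnected(userId);
                scheduleReconnect();
            }

            @Override
            public void onError(Exception ex) {
                scheduleReconnect();
            }
        };

        webSocketClient.connect();
    }

    private synchronized void scheduleReconnect() {
        if (reconnectPending || !reconnectionManager.shouldReconnect(userId)) {
            return;
        }
        reconnectPending = true;

        long delay = reconnectionManager.getNextReconnectDelay(userId);
        reconnectionManager.recordReconnectAttempt(userId);

        handler.postDelayed(() -> {
            reconnectPending = false;
            connect();
        }, delay);
    }

    private void sendHeartbeat() {
        // Stop the loop once the socket is closed; onOpen restarts it after a reconnect.
        if (webSocketClient == null || !webSocketClient.isOpen()) {
            return;
        }

        Map<String, Object> heartbeat = new HashMap<>();
        heartbeat.put("type", "HEARTBEAT");
        heartbeat.put("timestamp", System.currentTimeMillis());
        webSocketClient.send(JSON.toJSONString(heartbeat));

        handler.postDelayed(heartbeatTask, HEARTBEAT_INTERVAL_MS);
    }

    private void requestMessageSync() {
        // Tell the server the last sequence number stored locally so it can send the rest.
        Long lastSeq = getLocalLastSequenceNo();
        Map<String, Object> syncRequest = new HashMap<>();
        syncRequest.put("type", "SYNC_REQUEST");
        syncRequest.put("lastSequenceNo", lastSeq);
        webSocketClient.send(JSON.toJSONString(syncRequest));
    }
}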

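Results in practice

With this design we get:

1. Automatic reconnect

Connection drops → detected → wait with exponential backoff → reconnect succeeds → sync messages

2. No lost or duplicated messages

Send message → enters the pending queue → ACK received → removed from the queue
             → no ACK before the timeout → resent automatically
             → duplicate received → sequence number check → discarded

3. Precise message sync

Reconnect succeeds → client sends its last acknowledged sequence number → server resends the missed messages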
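
The sync step comes down to one range calculation: given the last sequence number the client holds and the latest one the server has issued, which sequence numbers are missing? The sketch below spells it out because off-by-one errors are easy to make here. `SyncRange` is a hypothetical helper, and both bounds of the returned range are inclusive.

```java
/** Computes which sequence numbers a reconnecting client is missing. */
final class SyncRange {

    /**
     * Returns {from, to}, both inclusive, or an empty array if nothing is missing.
     * The first missing message is lastReceivedSeq + 1, not lastReceivedSeq.
     */
    static long[] missed(long lastReceivedSeq, long latestServerSeq) {
        if (lastReceivedSeq >= latestServerSeq) {
            return new long[0]; // the client is up to date (or ahead, which is an anomaly)
        }
        return new long[] {lastReceivedSeq + 1, latestServerSeq};
    }
}
```

When this range is turned into a query, an inclusive range maps to `sequence_no >= from AND sequence_no <= to`. A query written with a strict `>` on the lower bound must be given `lastReceivedSeq` itself.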

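4. Heartbeat keep-alive

Client → sends a heartbeat every 15 seconds → server responds → heartbeat time updated
       → nothing received for 45 seconds → treated as timed out → reconnect triggered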

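Summary

This WebSocket keep-alive design for weak networks provides:

  1. Session management: all WebSocket connections are tracked in one place
  2. Sequence numbers: lost and duplicated messages can be detected
  3. Message queue: unacknowledged messages are held and can be resent
  4. Smart reconnect: exponential backoff avoids hammering the server with retries
  5. Heartbeat detection: dropped connections are noticed quickly
  6. Precise sync: missed messages are resent after a reconnect

I hope this article helps. If you found it useful, you are welcome to follow "服务端技术精选", where I will keep sharing practical technical content.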


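Title: SpringBoot + WebSocket Keep-Alive for Weak Networks: Reconnecting After the App Goes to the Background and Resending Exactly the Missed Messages
Author: jiangyi
URL: http://www.jiangyi.space/articles/2026/05/09/1777879630541.html
Official account: 服务端技术精选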