Spring WebSocket 封装

pom.xml

<!-- WebSocket support: Spring Boot WebSocket starter. No <version> is given here, so it is presumably managed by the Spring Boot parent/BOM; confirm in the full pom. -->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-websocket</artifactId>
</dependency>

HttpAuthInterceptor

package com.example.websocket;

import cn.hutool.http.HttpUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;

import java.nio.charset.StandardCharsets;
import java.util.Map;

@Component
@Slf4j
public class HttpAuthInterceptor implements HandshakeInterceptor {
    /**
     * Runs before the WebSocket handshake and lets it proceed only when the request carries a
     * non-blank {@code token} query parameter.
     *
     * <p>NOTE(review): only the token's presence is checked here; any non-blank value is accepted.
     * Real validation of the token still needs to be plugged in.
     *
     * @param request    the handshake HTTP request
     * @param response   the handshake HTTP response; set to 401 when the handshake is rejected
     * @param wsHandler  the target WebSocket handler
     * @param attributes attributes handed to the WebSocket session; the token is stored under {@code "token"}
     * @return {@code true} to continue the handshake, {@code false} to abort it
     */
    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) {
        // Decode the RAW query exactly once. getQuery() is already percent-decoded, so decoding it
        // again would turn '+' into ' ' and split values that contained an encoded '&'.
        String rawQuery = request.getURI().getRawQuery();
        String token = rawQuery == null
                ? null
                : HttpUtil.decodeParamMap(rawQuery, StandardCharsets.UTF_8).get("token");
        if (token == null || token.trim().isEmpty()) {
            // Tell the client why the upgrade failed instead of leaving the default status.
            log.warn("握手失败: 缺少token");
            response.setStatusCode(HttpStatus.UNAUTHORIZED);
            return false;
        }
        // Expose the token to the WebSocket session's attributes.
        attributes.put("token", token);
        log.info("用户握手成功");
        return true;
    }

    /**
     * Called after the handshake; logs the outcome, including the failure cause when there is one.
     *
     * @param request   the handshake HTTP request
     * @param response  the handshake HTTP response
     * @param wsHandler the target WebSocket handler
     * @param exception the exception raised during the handshake, or {@code null} if none
     */
    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception exception) {
        if (exception != null) {
            log.warn("握手失败", exception);
            return;
        }
        log.info("握手结束");
    }
}

WebSocket

package com.example.websocket;

import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.jetbrains.annotations.NotNull;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;

import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

@Component
@Slf4j
public class WebSocket extends TextWebSocketHandler {
    // 保存session
    private static final Map<String, WebSocketSession> SESSION_MAP = new ConcurrentHashMap<>();

    /**
     * 建立连接成功事件
     * @param session session
     * @throws Exception 异常
     */
    @Override
    public void afterConnectionEstablished(@NonNull WebSocketSession session) throws Exception {
        Object token = session.getAttributes().get("token");
        if (token == null) {
            // 没有token,关闭连接
            session.close();
            return;
        }
        // 保存session
        SESSION_MAP.put(token.toString(), session);
    }

    /**
     * 接收消息事件
     * @param session session
     * @param message 消息
     */
    @Override
    public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws IOException {
        Object payload = message.getPayload();
        log.info("收到消息:{}", payload);
        session.sendMessage(new TextMessage("收到消息:" + payload));
    }

    @Override
    public void handleTransportError(@NonNull WebSocketSession session, @NonNull Throwable exception) throws Exception {
        //异常处理
    }

    /**
     * 连接关闭事件
     * @param session session
     * @param closeStatus 关闭状态
     * @throws Exception 异常
     */
    @Override
    public void afterConnectionClosed(@NonNull WebSocketSession session, @NonNull CloseStatus closeStatus) throws Exception {
        Object token = session.getAttributes().get("token");
        if (token != null) {
            // 移除session
            SESSION_MAP.remove(token.toString());
        }
    }

    @Override
    public boolean supportsPartialMessages() {
        // 是否支持接收不完整的消息
        return false;
    }

    // 单独像某个用户发送消息
    public void sendMessageToUser(@NotNull String id, @NotNull String message) throws IOException {
        val session = SESSION_MAP.get(id);
        if (session != null) {
            session.sendMessage(new TextMessage(message));
        }
    }

    // 广播消息
    public void sendMessageToAllUsers(@NotNull String message) throws IOException {
        for (val session : SESSION_MAP.values()) {
            session.sendMessage(new TextMessage(message));
        }
    }
}

WebSocketConfig

package com.example.config;

import com.example.websocket.HttpAuthInterceptor;
import com.example.websocket.WebSocket;
import lombok.RequiredArgsConstructor;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.config.annotation.*;
import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor;

@Configuration
@EnableWebSocket
@RequiredArgsConstructor
public class WebSocketConfig implements WebSocketConfigurer {
    private final HttpAuthInterceptor httpAuthInterceptor;
    /** The Spring-managed handler bean; registered instead of creating a second, unmanaged instance. */
    private final WebSocket webSocket;

    /**
     * Registers the WebSocket handler together with its handshake interceptors.
     *
     * @param registry the registry to add the handler to
     */
    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        // NOTE(review): the mapping path is "" - confirm this is the intended endpoint (e.g. "/ws").
        // NOTE(review): "*" accepts handshakes from any origin; restrict it for production use.
        registry.addHandler(webSocket, "")
                // Interceptors run in the order given: the token check first, then HTTP-session attribute copying.
                .addInterceptors(httpAuthInterceptor, new HttpSessionHandshakeInterceptor())
                .setAllowedOrigins("*");
    }
}

最后更新于

这有帮助吗?