package com.github.paicoding.forum.web.front.chat.ws;

import com.github.paicoding.forum.api.model.context.ReqInfoContext;
import com.github.paicoding.forum.core.mdc.MdcUtil;
import com.github.paicoding.forum.core.util.SessionUtil;
import com.github.paicoding.forum.core.util.SpringUtil;
import com.github.paicoding.forum.web.global.GlobalInitService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor;

import java.util.Map;

/**
 * v1. 简单版本聊天: 长连接的登录校验拦截器
 *
 * @author YiHui
 * @date 2023/6/6
 */
@Slf4j
public class SimpleWsAuthInterceptor extends HttpSessionHandshakeInterceptor implements ChannelInterceptor {

    @Override
    public boolean preReceive(MessageChannel channel) {
        return ChannelInterceptor.super.preReceive(channel);
    }

    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
        String session = ((ServletServerHttpRequest) request).getServletRequest().getParameter("session");
        ReqInfoContext.ReqInfo reqInfo = new ReqInfoContext.ReqInfo();
        SpringUtil.getBean(GlobalInitService.class).initLoginUser(session, reqInfo);
        ReqInfoContext.addReqInfo(reqInfo);
        if (reqInfo.getUserId() == null) {
            // 未登录,拒绝链接
            log.info("用户未登录,拒绝聊天!");
            response.setStatusCode(HttpStatus.FORBIDDEN);
            return false;
        }
        log.info("{} 开始了聊天!", reqInfo);
        MdcUtil.addTraceId();
        return super.beforeHandshake(request, response, wsHandler, attributes);
    }

    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception ex) {
        ReqInfoContext.clear();
        MdcUtil.clear();
        super.afterHandshake(request, response, wsHandler, ex);
    }
}

这段代码定义了一个名为SimpleWsAuthInterceptor的类,它继承自HttpSessionHandshakeInterceptor并实现了ChannelInterceptor接口。这个拦截器用于在基于HTTP的WebSocket连接建立过程中进行用户身份验证,并在消息通道中进行权限检查。

主要方法详解

  1. preReceive(MessageChannel channel):

    • 此方法在消息接收之前被调用,用于在消息通道中进行权限检查。
    • 目前,这个方法直接调用了父类的实现,并返回了结果。
  2. beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes):

    • 此方法在WebSocket握手之前被调用,用于进行用户身份验证。
    • 从请求中提取session参数,这个参数应该是用户的会话标识。
    • 创建ReqInfoContext.ReqInfo对象,并调用GlobalInitService.initLoginUser方法来初始化用户信息。
    • 将初始化的用户信息添加到ReqInfoContext中。
    • 如果用户未登录(reqInfo.getUserId() == null),则拒绝连接,并设置响应状态为HttpStatus.FORBIDDEN
    • 如果用户已登录,记录日志并继续执行握手过程。
  3. afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception ex):

    • 此方法在WebSocket握手之后被调用,无论握手是否成功。
    • 清除ReqInfoContextMdcUtil中的信息,以防止潜在的安全问题。
    • 调用父类的方法来完成握手后的清理工作。

使用场景

SimpleWsAuthInterceptor拦截器适用于需要在WebSocket连接建立之前进行用户身份验证的场景。例如,在聊天应用中,只有验证通过的用户才能建立WebSocket连接,以确保通信的安全性。

注意事项

  • 拦截器中的权限验证逻辑需要根据实际的业务需求进行定制。
  • 在实际部署时,需要确保所有的WebSocket通信都通过这个拦截器进行,以保证安全性。
  • 日志记录语句应该根据实际的日志策略进行配置,以避免过多的日志输出影响性能。
  • session参数的验证逻辑应该足够健壮,以防止会话劫持等安全问题。
  • MdcUtil.addTraceId()方法用于添加追踪ID,这有助于在分布式追踪系统中跟踪请求的完整路径。

SimpleWsAuthInterceptor拦截器是WebSocket通信中一个重要的安全组件,它可以帮助开发者在连接建立阶段进行用户身份验证,确保只有合法的用户能够参与到WebSocket通信中。

Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐