Spring Boot 实现直播聊天室

技术方案:

  • spring boot
  • websocket
  • rabbitmq

使用 RabbitMQ 异步转发聊天消息，将消息接收与推送解耦，以提高系统吞吐量

引入依赖

<!-- Dependencies for the chat-room demo. The Spring artifacts carry no <version>;
     presumably their versions come from a Spring Boot parent/BOM not shown here - TODO confirm. -->
<dependencies>
    <!-- fastjson: JSON (de)serialization of MessageDto (JSON.toJSONString / JSON.parseObject) -->
    <dependency>
        <groupId>com.alibaba</groupId>
        <artifactId>fastjson</artifactId>
        <version>2.0.42</version>
    </dependency>
    <!-- Hutool utilities (StrUtil is used by SessionRegistry) -->
    <dependency>
        <groupId>cn.hutool</groupId>
        <artifactId>hutool-all</artifactId>
        <version>5.8.23</version>
    </dependency>
    <!-- Servlet-based web stack (the handshake interceptor works with ServletServerHttpRequest) -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
    <!-- Spring WebSocket support: WebSocketConfigurer, WebSocketHandler, handshake interceptors -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-websocket</artifactId>
    </dependency>
    <!-- Lombok: @Slf4j and @Data; optional so it is not propagated to consumers -->
    <dependency>
        <groupId>org.projectlombok</groupId>
        <artifactId>lombok</artifactId>
        <optional>true</optional>
    </dependency>
    <!-- Test support, test scope only -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-test</artifactId>
        <scope>test</scope>
    </dependency>
    <!-- Spring AMQP: RabbitTemplate (MessageClient) and @RabbitListener (MessageListener).
         NOTE(review): spring-boot-starter-amqp is the usual Boot starter for this - consider it. -->
    <dependency>
        <groupId>org.springframework.amqp</groupId>
        <artifactId>spring-rabbit</artifactId>
    </dependency>
</dependencies>

WebSocket 实现

MHttpSessionHandshakeInterceptor

参数拦截

/**
 * WebSocket handshake interceptor that copies connection parameters into the handshake attributes.
 *
 * <p>For servlet requests it takes the chat room id from the last path segment of the request URI
 * (for {@code ws://127.0.0.1:8080/group/2?username=xxxx} that is {@code "2"}) and the
 * {@code username} query parameter. It stores them under the attribute keys {@code "groupId"} and
 * {@code "username"} so the {@link WebSocketHandler} can read them from the session attributes.
 *
 * <p>NOTE(review): no token validation is performed yet. After the attributes are populated, every
 * handshake is delegated unchanged to {@link HttpSessionHandshakeInterceptor}.
 */
@Slf4j
public class MHttpSessionHandshakeInterceptor extends HttpSessionHandshakeInterceptor {

    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
                                   WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
        // Non-servlet requests carry no parameters we know how to read; just delegate.
        if (!(request instanceof ServletServerHttpRequest servletRequest)) {
            return super.beforeHandshake(request, response, wsHandler, attributes);
        }

        HttpServletRequest rawRequest = servletRequest.getServletRequest();
        String uri = rawRequest.getRequestURI();
        // The "{groupId}" path placeholder is the final segment of the URI.
        String groupId = uri.substring(uri.lastIndexOf('/') + 1);
        String username = rawRequest.getParameter("username");

        log.info(">>>>>>> beforeHandshake groupId: {} - username: {}", groupId, username);
        attributes.put("username", username);
        attributes.put("groupId", groupId);
        return super.beforeHandshake(request, response, wsHandler, attributes);
    }
}

GroupWebSocketHandler

消息处理：接收客户端消息并投递到 RabbitMQ

@Slf4j
public class GroupWebSocketHandler implements WebSocketHandler {

    //Map<room,List<map<session,username>>>
    private ConcurrentHashMap<String, Queue<WebSocketSession>> sessionMap = new ConcurrentHashMap<>();

    @Autowired
    private MessageClient messagingClient;


    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        String username = (String) session.getAttributes().get("username");
        String groupId = (String) session.getAttributes().get("groupId");
        log.info("{} 用户上线房间 {}", username, groupId);
        TomcatWsSession wsSession = new TomcatWsSession(session.getId(),groupId, username, session);
        SessionRegistry.getInstance().addSession(wsSession);
    }

    @Override
    public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
        String groupId = (String) session.getAttributes().get("groupId");
        String username = (String) session.getAttributes().get("username");
        if (message instanceof PingMessage){
            log.info("PING");
            return;
        }
        else if (message instanceof TextMessage textMessage) {
            MessageDto messageDto = new MessageDto();
            messageDto.setSessionId(session.getId());
            messageDto.setGroup(groupId);
            messageDto.setFromUser(username);
            messageDto.setContent(new String(textMessage.getPayload()));
            messagingClient.sendMessage(messageDto);
        }
    }

    @Override
    public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
        String username = (String) session.getAttributes().get("username");
        String groupId = (String) session.getAttributes().get("groupId");
        log.info(">>> handleTransportError {} 用户上线房间 {}", username, groupId);
    }

    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
        String username = (String) session.getAttributes().get("username");
        String groupId = (String) session.getAttributes().get("groupId");
        log.info("{} 用户下线房间 {}", username, groupId);
        TomcatWsSession wsSession = new TomcatWsSession(session.getId(),groupId, username, session);
        SessionRegistry.getInstance().removeSession(wsSession);
    }

    @Override
    public boolean supportsPartialMessages() {
        return false;
    }


}
WebSocketConfig

WebSocket 配置

/**
 * Registers the chat-room WebSocket endpoint {@code /group/{groupId}} and tunes the servlet
 * WebSocket container.
 */
@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {

    /** Buffer limit applied to both text and binary messages. */
    private static final int MESSAGE_BUFFER_SIZE = 8192;
    /** Idle connections are closed after 3 minutes (milliseconds). */
    private static final long MAX_IDLE_MILLIS = 3L * 60 * 1000;
    /** Timeout for asynchronous sends: 10 seconds (milliseconds). */
    private static final long ASYNC_SEND_TIMEOUT_MILLIS = 10L * 1000;

    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        // NOTE(review): "*" accepts handshakes from any Origin, which allows cross-site WebSocket
        // hijacking; restrict this to the real front-end origins before going to production.
        registry.addHandler(myHandler(), "/group/{groupId}")
                .addInterceptors(new MHttpSessionHandshakeInterceptor())
                .setAllowedOrigins("*");
    }

    /** The chat-room handler, exposed as a bean so its {@code @Autowired} fields are injected. */
    @Bean
    public GroupWebSocketHandler myHandler() {
        return new GroupWebSocketHandler();
    }

    @Bean
    public ServletServerContainerFactoryBean createWebSocketContainer() {
        ServletServerContainerFactoryBean factory = new ServletServerContainerFactoryBean();
        factory.setMaxTextMessageBufferSize(MESSAGE_BUFFER_SIZE);
        factory.setMaxBinaryMessageBufferSize(MESSAGE_BUFFER_SIZE);
        factory.setMaxSessionIdleTimeout(MAX_IDLE_MILLIS);
        factory.setAsyncSendTimeout(ASYNC_SEND_TIMEOUT_MILLIS);
        return factory;
    }
}

session 管理

对 WebSocketSession 进行抽象，使 WebSocketSession 可以由不同容器实现

WsSession
/**
 * Container-independent view of one chat-room WebSocket connection, so that the session registry
 * does not depend on a particular WebSocket implementation.
 */
public interface WsSession {

    /**
     * Returns the unique id of the underlying connection.
     *
     * @return the session id
     */
    String getId();

    /**
     * Returns the group (chat room) this session belongs to.
     *
     * @return the group id
     */
    String group();

    /**
     * Returns the user name or another unique identity of the connected client.
     *
     * @return the client identity
     */
    String identity();

    /**
     * Delivers a chat message to this session's client as a text message.
     *
     * @param messageDto the message to deliver
     */
    void sendTextMessage(MessageDto messageDto);
}

/**
 * Base {@link WsSession} holding the immutable id / group / identity triple.
 *
 * <p>Two sessions are equal when their ids are equal. This lets a freshly constructed instance with
 * the same id remove a registered session from a collection, as
 * {@code SessionRegistry#removeSession} relies on.
 */
public abstract class AbstractWsSession implements WsSession {

    private final String id;
    private final String group;
    private final String identity;

    public AbstractWsSession(String id, String group, String identity) {
        this.id = id;
        this.group = group;
        this.identity = identity;
    }

    @Override
    public String group() {
        return this.group;
    }

    @Override
    public String getId() {
        return this.id;
    }

    @Override
    public String identity() {
        return this.identity;
    }

    /** Compares only the session id. */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        AbstractWsSession that = (AbstractWsSession) o;
        return Objects.equals(id, that.id);
    }

    /**
     * Hashes only the session id, consistent with {@link #equals(Object)}. Hashing group/identity
     * too would give equal sessions different hash codes and break hash-based collections.
     */
    @Override
    public int hashCode() {
        return Objects.hashCode(id);
    }
}

TomcatWsSession

默认 session 实现（基于 Spring 的 WebSocketSession）

/**
 * Default {@link WsSession}, backed by a Spring {@link WebSocketSession}.
 *
 * <p>Messages are rendered as {@code "<fromUser> say: <content>"} text frames.
 *
 * <p>NOTE(review): Spring documents that sending on one WebSocket session from several threads
 * concurrently is not allowed. If more than one thread can deliver to the same session (e.g.
 * listener concurrency &gt; 1), wrap the session in a {@code ConcurrentWebSocketSessionDecorator}.
 */
@Slf4j
public class TomcatWsSession extends AbstractWsSession {

    private final WebSocketSession webSocketSession;

    public TomcatWsSession(String id, String group, String identity, WebSocketSession webSocketSession) {
        super(id, group, identity);
        this.webSocketSession = webSocketSession;
    }

    /**
     * Sends {@code messageDto} to this client. Failures are logged and swallowed so that one broken
     * connection does not abort delivery to the other members of the group.
     */
    @Override
    public void sendTextMessage(MessageDto messageDto) {
        // The session may have closed between the registry lookup and this send; skip it.
        if (!webSocketSession.isOpen()) {
            log.debug("TomcatWsSession skip closed session: identity:{}-group:{}", identity(), group());
            return;
        }
        String content = messageDto.getFromUser() + " say: " + messageDto.getContent();
        try {
            webSocketSession.sendMessage(new TextMessage(content));
        } catch (IOException | IllegalStateException e) {
            // IllegalStateException: the container may report a session that closed concurrently
            // this way rather than with an IOException.
            log.error("TomcatWsSession sendTextMessage error: identity:{}-group:{}-msg: {}",
                    identity(), group(), JSON.toJSONString(messageDto), e);
        }
    }
}

SessionRegistry

WebSocket session 管理

/**
 * Process-wide registry of connected {@link WsSession}s, grouped by chat room.
 *
 * <p>Per-group queues are created, mutated and discarded inside
 * {@link ConcurrentHashMap#compute} / {@link ConcurrentHashMap#computeIfPresent}. These run
 * atomically per key, so a session added while the last member of its group is being removed
 * cannot end up in a queue that has already been dropped from the map.
 */
public class SessionRegistry {

    /** Lazy, thread-safe singleton via the initialization-on-demand holder idiom. */
    private static final class Holder {
        private static final SessionRegistry INSTANCE = new SessionRegistry();
    }

    /** group id -> sessions currently connected to that group. */
    private final ConcurrentHashMap<String, Queue<WsSession>> sessionMap = new ConcurrentHashMap<>();

    private SessionRegistry() {
    }

    /**
     * Returns the shared registry. Replaces the previous double-checked locking on a non-volatile
     * field, which could publish a partially constructed instance.
     */
    public static SessionRegistry getInstance() {
        return Holder.INSTANCE;
    }

    /**
     * Adds a session to its group, creating the group on first use.
     *
     * @param wsSession the session to register
     */
    public void addSession(WsSession wsSession) {
        sessionMap.compute(wsSession.group(), (group, sessions) -> {
            Queue<WsSession> target = sessions != null ? sessions : new ConcurrentLinkedDeque<>();
            target.add(wsSession);
            return target;
        });
    }

    /**
     * Removes a session from its group and drops the group once it becomes empty. Matching uses
     * {@code WsSession#equals}, which must compare session ids for removal to work.
     *
     * @param wsSession the session to unregister
     */
    public void removeSession(WsSession wsSession) {
        sessionMap.computeIfPresent(wsSession.group(), (group, sessions) -> {
            sessions.remove(wsSession);
            // Returning null removes the mapping atomically.
            return sessions.isEmpty() ? null : sessions;
        });
    }

    /**
     * Sends a message to every session of its group except the sender's own session.
     *
     * @param messageDto the message; does nothing when its group is null or unknown
     */
    public void sendGroupTextMessage(MessageDto messageDto) {
        String group = messageDto.getGroup();
        if (group == null) {
            return;
        }
        Queue<WsSession> wsSessions = sessionMap.get(group);
        if (wsSessions == null) {
            return;
        }
        for (WsSession wsSession : wsSessions) {
            if (wsSession.getId().equals(messageDto.getSessionId())) {
                continue;
            }
            wsSession.sendTextMessage(messageDto);
        }
    }

    /**
     * Counts online sessions.
     *
     * @param groupId a group id, or blank for the total across all groups
     * @return the number of sessions in {@code groupId} (0 if the group is unknown), or the total
     */
    public Integer getSessionCount(String groupId) {
        if (StrUtil.isNotBlank(groupId)) {
            Queue<WsSession> sessions = sessionMap.get(groupId);
            return sessions == null ? 0 : sessions.size();
        }
        return sessionMap.values().stream().mapToInt(Queue::size).sum();
    }
}

消息队列

这里使用 RabbitMQ

MessageDto

消息体

/**
 * Chat message passed from the WebSocket handler through RabbitMQ (as fastjson JSON) to the
 * session registry. Accessors, equals/hashCode and toString are generated by Lombok {@code @Data}.
 */
@Data
public class MessageDto {

    /**
     * WebSocket session id of the sender; the registry skips this session when fanning out.
     */
    private String sessionId;
    /**
     * Group (chat room) id the message is addressed to.
     */
    private String group;
    /**
     * Sender of the message (the handshake {@code username}).
     */
    private String fromUser;
    /**
     * Message text.
     */
    private String content;
}
MessageClient
@Component
@Slf4j
public class MessageClient {

    private String routeKey = "bws.key";
    private String exchange = "bws.exchange";

    @Autowired
    private RabbitTemplate rabbitTemplate;


    public void sendMessage(MessageDto messageDto) {
        try {
            rabbitTemplate.convertAndSend(exchange, routeKey, JSON.toJSONString(messageDto));
        } catch (AmqpException e) {
            log.error("MessageClient.sendMessage: {}", JSON.toJSONString(messageDto), e);
        }
    }
}
MessageListener
@Slf4j
@Component
public class MessageListener {

    @RabbitListener(bindings = @QueueBinding(exchange = @Exchange(value = "bws.exchange", type = "topic"), value =
    @Queue(value = "bws.queue", durable = "true"), key = "bws.key"))
    public void onMessage(Message message) {
        String messageStr = "";
        try {
            messageStr = new String(message.getBody(), StandardCharsets.UTF_8);
            log.info("<<<<<<<<< MessageListener.onMessage:{}", messageStr);
            MessageDto messageDto = JSON.parseObject(messageStr, MessageDto.class);
            if (!Objects.isNull(messageDto)) {
                SessionRegistry.getInstance().sendGroupTextMessage(messageDto);
            } else {
                log.info("<<<<<<<<< MessageListener.onMessage is null:{}", messageStr);
            }
        } catch (Exception e) {
            log.error("######### MessageListener.onMessage: {}-{}", messageStr, e);
        }
    }

}

application.properties 配置


# RabbitMQ connection used by RabbitTemplate (MessageClient) and @RabbitListener (MessageListener).
# Replace the placeholder host with your broker address.
spring.rabbitmq.host=192.168.x.x
# NOTE(review): RabbitMQ's built-in "guest" user is by default only allowed to connect from localhost;
# a remote broker presumably needs a dedicated user - confirm.
spring.rabbitmq.password=guest
# NOTE(review): 27067 is not RabbitMQ's default AMQP port (5672) - confirm it matches your broker.
spring.rabbitmq.port=27067
spring.rabbitmq.username=guest
spring.rabbitmq.virtual-host=my-cluster

测试

WebSocket 连接地址：ws://127.0.0.1:8080/group/2?username=xxx，可使用 WebSocket 客户端测试地址进行测试

在这里插入图片描述

good luck!

Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐