修复Websocket消息类型过滤不设置时无法接收数据包的问题,服务器增加连接分流功能

This commit is contained in:
kercylan98 2023-05-15 10:01:09 +08:00
parent df4aa30743
commit 926b69bee1
7 changed files with 118 additions and 29 deletions

11
main.go Normal file
View File

@ -0,0 +1,11 @@
package main
import "github.com/kercylan98/minotaur/server"
// 无意义的测试main入口
func main() {
srv := server.New(server.NetworkWebsocket, server.WithConnectPacketDiversion(3, 2))
if err := srv.Run(":8999"); err != nil {
panic(err)
}
}

View File

@ -4,11 +4,14 @@ import (
"github.com/gorilla/websocket"
"github.com/panjf2000/gnet"
"github.com/xtaci/kcp-go/v5"
"net"
"strings"
)
// newKcpConn 创建一个处理KCP的连接
func newKcpConn(session *kcp.UDPSession) *Conn {
return &Conn{
c := &Conn{
remoteAddr: session.RemoteAddr(),
ip: session.RemoteAddr().String(),
kcp: session,
write: func(data []byte) error {
@ -17,11 +20,16 @@ func newKcpConn(session *kcp.UDPSession) *Conn {
},
data: map[any]any{},
}
if index := strings.LastIndex(c.ip, ":"); index != -1 {
c.ip = c.ip[0:index]
}
return c
}
// newKcpConn 创建一个处理GNet的连接
func newGNetConn(conn gnet.Conn) *Conn {
return &Conn{
c := &Conn{
remoteAddr: conn.RemoteAddr(),
ip: conn.RemoteAddr().String(),
gn: conn,
write: func(data []byte) error {
@ -29,11 +37,16 @@ func newGNetConn(conn gnet.Conn) *Conn {
},
data: map[any]any{},
}
if index := strings.LastIndex(c.ip, ":"); index != -1 {
c.ip = c.ip[0:index]
}
return c
}
// newKcpConn 创建一个处理WebSocket的连接
func newWebsocketConn(ws *websocket.Conn, ip string) *Conn {
return &Conn{
remoteAddr: ws.RemoteAddr(),
ip: ip,
ws: ws,
write: func(data []byte) error {
@ -45,6 +58,7 @@ func newWebsocketConn(ws *websocket.Conn, ip string) *Conn {
// Conn 服务器连接
type Conn struct {
remoteAddr net.Addr
ip string
ws *websocket.Conn
gn gnet.Conn
@ -53,7 +67,15 @@ type Conn struct {
data map[any]any
}
func (slf *Conn) RemoteAddr() net.Addr {
return slf.remoteAddr
}
func (slf *Conn) GetID() string {
return slf.remoteAddr.String()
}
func (slf *Conn) GetIP() string {
return slf.ip
}

View File

@ -16,4 +16,5 @@ var (
ErrWebsocketIllegalMessageType = errors.New("illegal message type")
ErrPleaseUseWebsocketHandle = errors.New("in Websocket mode, please use the RegConnectionReceiveWebsocketPacketEvent function to register")
ErrPleaseUseOrdinaryPacketHandle = errors.New("non Websocket mode, please use the RegConnectionReceivePacketEvent function to register")
ErrOnlySupportSocket = errors.New("only supports Socket programming")
)

View File

@ -74,7 +74,11 @@ func (slf *event) RegConnectionOpenedEvent(handle ConnectionOpenedEventHandle) {
}
func (slf *event) OnConnectionOpenedEvent(conn *Conn) {
if len(slf.diversionMessageChannels) == 0 {
log.Debug("Server", zap.String("ConnectionOpened", conn.GetID()))
} else {
log.Debug("Server", zap.String("ConnectionOpened", conn.GetID()), zap.Int("Node", slf.diversionConsistency.PickNode(conn.GetID())))
}
for _, handle := range slf.connectionOpenedEventHandles {
handle(slf.Server, conn)
}

View File

@ -1,6 +1,7 @@
package server
import (
"github.com/kercylan98/minotaur/utils/hash"
"github.com/kercylan98/minotaur/utils/log"
"go.uber.org/zap"
"google.golang.org/grpc"
@ -21,6 +22,25 @@ const (
type Option func(srv *Server)
// WithConnectPacketDiversion 通过连接数据包消息分流的方式创建服务器
// - 连接消息分流后数据包消息将会从其他消息类型中独立出来,并且由多个消息管道及协程进行处理
// - 默认不会进行消息分流
// - 需要注意并发编程
func WithConnectPacketDiversion(diversionNumber, channelSize int) Option {
return func(srv *Server) {
if srv.network == NetworkHttp || srv.network == NetworkGRPC {
log.Warn("WithConnectPacketDiversion", zap.String("Network", string(srv.network)), zap.Error(ErrOnlySupportSocket))
return
}
srv.diversionMessageChannels = make([]chan *message, diversionNumber)
srv.diversionConsistency = hash.NewConsistency(3)
for i := 0; i < diversionNumber; i++ {
srv.diversionMessageChannels[i] = make(chan *message, channelSize)
srv.diversionConsistency.AddNode(i + 1)
}
}
}
// WithTLS 通过安全传输层协议TLS创建服务器
// - 支持Http、Websocket
func WithTLS(certFile, keyFile string) Option {

View File

@ -5,6 +5,7 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/kercylan98/minotaur/utils/hash"
"github.com/kercylan98/minotaur/utils/log"
"github.com/kercylan98/minotaur/utils/synchronization"
"github.com/panjf2000/gnet"
@ -70,6 +71,8 @@ type Server struct {
multiple bool // 是否为多服务器模式下运行
prod bool // 是否为生产模式
core int // 消息处理核心数
diversionMessageChannels []chan *message // 分流消息管道
diversionConsistency *hash.Consistency // 哈希一致性分流器
}
// Run 使用特定地址运行服务器
@ -118,6 +121,15 @@ func (slf *Server) Run(addr string) error {
slf.dispatchMessage(message)
}
}()
go func() {
for i := 0; i < len(slf.diversionMessageChannels); i++ {
go func(channel chan *message) {
for message := range channel {
slf.dispatchMessage(message)
}
}(slf.diversionMessageChannels[i])
}
}()
}
}
@ -249,7 +261,7 @@ func (slf *Server) Run(addr string) error {
if err != nil {
panic(err)
}
if !slf.supportMessageTypes[messageType] {
if len(slf.supportMessageTypes) > 0 && !slf.supportMessageTypes[messageType] {
panic(ErrWebsocketIllegalMessageType)
}
slf.PushMessage(MessageTypePacket, conn, packet, messageType)
@ -311,6 +323,11 @@ func (slf *Server) IsDev() bool {
// Shutdown 停止运行服务器
func (slf *Server) Shutdown(err error) {
slf.isShutdown.Store(true)
if len(slf.diversionMessageChannels) > 0 {
for i := 0; i < len(slf.diversionMessageChannels); i++ {
close(slf.diversionMessageChannels[i])
}
}
if slf.initMessageChannel {
if slf.gServer != nil {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
@ -364,7 +381,12 @@ func (slf *Server) PushMessage(messageType MessageType, attrs ...any) {
msg := slf.messagePool.Get()
msg.t = messageType
msg.attrs = attrs
if messageType == MessageTypePacket && len(slf.diversionMessageChannels) > 0 {
conn := attrs[0].(*Conn)
slf.diversionMessageChannels[slf.diversionConsistency.PickNode(conn.ip)] <- msg
} else {
slf.messageChannel <- msg
}
}
// dispatchMessage 消息分发
@ -381,6 +403,9 @@ func (slf *Server) dispatchMessage(msg *message) {
case MessageTypePacket:
if slf.network == NetworkWebsocket {
conn, packet, messageType := msg.t.deconstructWebSocketPacket(msg.attrs...)
if slf.diversionConsistency != nil {
slf.diversionConsistency.PickNode(conn)
}
slf.OnConnectionReceiveWebsocketPacketEvent(conn, packet, messageType)
} else {
conn, packet := msg.t.deconstructPacket(msg.attrs...)

View File

@ -8,11 +8,17 @@ import (
"strings"
)
func NewConsistency(replicas int) *Consistency {
return &Consistency{
replicas: replicas,
}
}
// Consistency 一致性哈希生成
//
// https://blog.csdn.net/zhpCSDN921011/article/details/126845397
type Consistency struct {
Replicas int // 虚拟节点的数量
replicas int // 虚拟节点的数量
keys []int // 所有虚拟节点的哈希值
hashMap map[int]int // 虚拟节点的哈希值: 节点(虚拟节点映射到真实节点)
}
@ -22,11 +28,11 @@ func (slf *Consistency) AddNode(keys ...int) {
if slf.hashMap == nil {
slf.hashMap = map[int]int{}
}
if slf.Replicas == 0 {
slf.Replicas = 3
if slf.replicas == 0 {
slf.replicas = 3
}
for _, key := range keys {
for i := 0; i < slf.Replicas; i++ {
for i := 0; i < slf.replicas; i++ {
// 计算虚拟节点哈希值
hash := int(crc32.ChecksumIEEE([]byte(strconv.Itoa(i) + strconv.Itoa(key))))