diff --git a/main.go b/main.go new file mode 100644 index 0000000..297800b --- /dev/null +++ b/main.go @@ -0,0 +1,11 @@ +package main + +import "github.com/kercylan98/minotaur/server" + +// 无意义的测试main入口 +func main() { + srv := server.New(server.NetworkWebsocket, server.WithConnectPacketDiversion(3, 2)) + if err := srv.Run(":8999"); err != nil { + panic(err) + } +} diff --git a/server/conn.go b/server/conn.go index e7b0ff6..5f58f0c 100644 --- a/server/conn.go +++ b/server/conn.go @@ -4,38 +4,51 @@ import ( "github.com/gorilla/websocket" "github.com/panjf2000/gnet" "github.com/xtaci/kcp-go/v5" + "net" + "strings" ) // newKcpConn 创建一个处理KCP的连接 func newKcpConn(session *kcp.UDPSession) *Conn { - return &Conn{ - ip: session.RemoteAddr().String(), - kcp: session, + c := &Conn{ + remoteAddr: session.RemoteAddr(), + ip: session.RemoteAddr().String(), + kcp: session, write: func(data []byte) error { _, err := session.Write(data) return err }, data: map[any]any{}, } + if index := strings.LastIndex(c.ip, ":"); index != -1 { + c.ip = c.ip[0:index] + } + return c } // newKcpConn 创建一个处理GNet的连接 func newGNetConn(conn gnet.Conn) *Conn { - return &Conn{ - ip: conn.RemoteAddr().String(), - gn: conn, + c := &Conn{ + remoteAddr: conn.RemoteAddr(), + ip: conn.RemoteAddr().String(), + gn: conn, write: func(data []byte) error { return conn.AsyncWrite(data) }, data: map[any]any{}, } + if index := strings.LastIndex(c.ip, ":"); index != -1 { + c.ip = c.ip[0:index] + } + return c } // newKcpConn 创建一个处理WebSocket的连接 func newWebsocketConn(ws *websocket.Conn, ip string) *Conn { return &Conn{ - ip: ip, - ws: ws, + remoteAddr: ws.RemoteAddr(), + ip: ip, + ws: ws, write: func(data []byte) error { return ws.WriteMessage(websocket.BinaryMessage, data) }, @@ -45,15 +58,24 @@ func newWebsocketConn(ws *websocket.Conn, ip string) *Conn { // Conn 服务器连接 type Conn struct { - ip string - ws *websocket.Conn - gn gnet.Conn - kcp *kcp.UDPSession - write func(data []byte) error - data map[any]any + remoteAddr net.Addr + ip string + ws *websocket.Conn + gn gnet.Conn + kcp *kcp.UDPSession + write func(data []byte) error + data map[any]any +} + +func (slf *Conn) RemoteAddr() net.Addr { + return slf.remoteAddr } func (slf *Conn) GetID() string { + return slf.remoteAddr.String() +} + +func (slf *Conn) GetIP() string { return slf.ip } diff --git a/server/errors.go b/server/errors.go index 40eeafb..cf1da3b 100644 --- a/server/errors.go +++ b/server/errors.go @@ -16,4 +16,5 @@ var ( ErrWebsocketIllegalMessageType = errors.New("illegal message type") ErrPleaseUseWebsocketHandle = errors.New("in Websocket mode, please use the RegConnectionReceiveWebsocketPacketEvent function to register") ErrPleaseUseOrdinaryPacketHandle = errors.New("non Websocket mode, please use the RegConnectionReceivePacketEvent function to register") + ErrOnlySupportSocket = errors.New("only supports Socket programming") ) diff --git a/server/event.go b/server/event.go index 10ad20e..138cac5 100644 --- a/server/event.go +++ b/server/event.go @@ -74,7 +74,11 @@ func (slf *event) RegConnectionOpenedEvent(handle ConnectionOpenedEventHandle) { } func (slf *event) OnConnectionOpenedEvent(conn *Conn) { - log.Debug("Server", zap.String("ConnectionOpened", conn.GetID())) + if len(slf.diversionMessageChannels) == 0 { + log.Debug("Server", zap.String("ConnectionOpened", conn.GetID())) + } else { + log.Debug("Server", zap.String("ConnectionOpened", conn.GetID()), zap.Int("Node", slf.diversionConsistency.PickNode(conn.GetID()))) + } for _, handle := range slf.connectionOpenedEventHandles { handle(slf.Server, conn) } diff --git a/server/options.go b/server/options.go index 6ef0b26..43e8a58 100644 --- a/server/options.go +++ b/server/options.go @@ -1,6 +1,7 @@ package server import ( + "github.com/kercylan98/minotaur/utils/hash" "github.com/kercylan98/minotaur/utils/log" "go.uber.org/zap" "google.golang.org/grpc" @@ -21,6 +22,25 @@ const ( type Option func(srv *Server) +// WithConnectPacketDiversion 通过连接数据包消息分流的方式创建服务器 +// - 连接消息分流后数据包消息将会从其他消息类型中独立出来,并且由多个消息管道及协程进行处理 +// - 默认不会进行消息分流 +// - 需要注意并发编程 +func WithConnectPacketDiversion(diversionNumber, channelSize int) Option { + return func(srv *Server) { + if srv.network == NetworkHttp || srv.network == NetworkGRPC { + log.Warn("WithConnectPacketDiversion", zap.String("Network", string(srv.network)), zap.Error(ErrOnlySupportSocket)) + return + } + srv.diversionMessageChannels = make([]chan *message, diversionNumber) + srv.diversionConsistency = hash.NewConsistency(3) + for i := 0; i < diversionNumber; i++ { + srv.diversionMessageChannels[i] = make(chan *message, channelSize) + srv.diversionConsistency.AddNode(i + 1) + } + } +} + // WithTLS 通过安全传输层协议TLS创建服务器 // - 支持:Http、Websocket func WithTLS(certFile, keyFile string) Option { diff --git a/server/server.go b/server/server.go index d522ce2..61cdba4 100644 --- a/server/server.go +++ b/server/server.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" + "github.com/kercylan98/minotaur/utils/hash" "github.com/kercylan98/minotaur/utils/log" "github.com/kercylan98/minotaur/utils/synchronization" "github.com/panjf2000/gnet" @@ -62,14 +63,16 @@ type Server struct { isShutdown atomic.Bool // 是否已关闭 closeChannel chan struct{} // 关闭信号 - gServer *gNet // TCP或UDP模式下的服务器 - messagePool *synchronization.Pool[*message] // 消息池 - messagePoolSize int // 消息池大小 - messageChannel chan *message // 消息管道 - initMessageChannel bool // 消息管道是否已经初始化 - multiple bool // 是否为多服务器模式下运行 - prod bool // 是否为生产模式 - core int // 消息处理核心数 + gServer *gNet // TCP或UDP模式下的服务器 + messagePool *synchronization.Pool[*message] // 消息池 + messagePoolSize int // 消息池大小 + messageChannel chan *message // 消息管道 + initMessageChannel bool // 消息管道是否已经初始化 + multiple bool // 是否为多服务器模式下运行 + prod bool // 是否为生产模式 + core int // 消息处理核心数 + diversionMessageChannels []chan *message // 分流消息管道 + diversionConsistency *hash.Consistency // 哈希一致性分流器 } // Run 使用特定地址运行服务器 @@ -118,6 +121,15 @@ func (slf *Server) Run(addr string) error { slf.dispatchMessage(message) } }() + go func() { + for i := 0; i < len(slf.diversionMessageChannels); i++ { + go func(channel chan *message) { + for message := range channel { + slf.dispatchMessage(message) + } + }(slf.diversionMessageChannels[i]) + } + }() } } @@ -249,7 +261,7 @@ func (slf *Server) Run(addr string) error { if err != nil { panic(err) } - if !slf.supportMessageTypes[messageType] { + if len(slf.supportMessageTypes) > 0 && !slf.supportMessageTypes[messageType] { panic(ErrWebsocketIllegalMessageType) } slf.PushMessage(MessageTypePacket, conn, packet, messageType) @@ -311,6 +323,11 @@ func (slf *Server) IsDev() bool { // Shutdown 停止运行服务器 func (slf *Server) Shutdown(err error) { slf.isShutdown.Store(true) + if len(slf.diversionMessageChannels) > 0 { + for i := 0; i < len(slf.diversionMessageChannels); i++ { + close(slf.diversionMessageChannels[i]) + } + } if slf.initMessageChannel { if slf.gServer != nil { ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) @@ -364,7 +381,12 @@ func (slf *Server) PushMessage(messageType MessageType, attrs ...any) { msg := slf.messagePool.Get() msg.t = messageType msg.attrs = attrs - slf.messageChannel <- msg + if messageType == MessageTypePacket && len(slf.diversionMessageChannels) > 0 { + conn := attrs[0].(*Conn) + slf.diversionMessageChannels[slf.diversionConsistency.PickNode(conn.ip)] <- msg + } else { + slf.messageChannel <- msg + } } // dispatchMessage 消息分发 @@ -381,6 +403,9 @@ func (slf *Server) dispatchMessage(msg *message) { case MessageTypePacket: if slf.network == NetworkWebsocket { conn, packet, messageType := msg.t.deconstructWebSocketPacket(msg.attrs...) + if slf.diversionConsistency != nil { + slf.diversionConsistency.PickNode(conn) + } slf.OnConnectionReceiveWebsocketPacketEvent(conn, packet, messageType) } else { conn, packet := msg.t.deconstructPacket(msg.attrs...) diff --git a/utils/hash/consistency.go b/utils/hash/consistency.go index b1e2cfb..6c2f310 100644 --- a/utils/hash/consistency.go +++ b/utils/hash/consistency.go @@ -8,11 +8,17 @@ import ( "strings" ) +func NewConsistency(replicas int) *Consistency { + return &Consistency{ + replicas: replicas, + } +} + // Consistency 一致性哈希生成 // // https://blog.csdn.net/zhpCSDN921011/article/details/126845397 type Consistency struct { - Replicas int // 虚拟节点的数量 + replicas int // 虚拟节点的数量 keys []int // 所有虚拟节点的哈希值 hashMap map[int]int // 虚拟节点的哈希值: 节点(虚拟节点映射到真实节点) } @@ -22,11 +28,11 @@ func (slf *Consistency) AddNode(keys ...int) { if slf.hashMap == nil { slf.hashMap = map[int]int{} } - if slf.Replicas == 0 { - slf.Replicas = 3 + if slf.replicas == 0 { + slf.replicas = 3 } for _, key := range keys { - for i := 0; i < slf.Replicas; i++ { + for i := 0; i < slf.replicas; i++ { // 计算虚拟节点哈希值 hash := int(crc32.ChecksumIEEE([]byte(strconv.Itoa(i) + strconv.Itoa(key))))