修复Websocket消息类型过滤不设置时无法接收数据包的问题,服务器增加连接分流功能
This commit is contained in:
parent
df4aa30743
commit
926b69bee1
|
@ -0,0 +1,11 @@
|
|||
package main
|
||||
|
||||
import "github.com/kercylan98/minotaur/server"
|
||||
|
||||
// 无意义的测试main入口
|
||||
func main() {
|
||||
srv := server.New(server.NetworkWebsocket, server.WithConnectPacketDiversion(3, 2))
|
||||
if err := srv.Run(":8999"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
|
@ -4,11 +4,14 @@ import (
|
|||
"github.com/gorilla/websocket"
|
||||
"github.com/panjf2000/gnet"
|
||||
"github.com/xtaci/kcp-go/v5"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// newKcpConn 创建一个处理KCP的连接
|
||||
func newKcpConn(session *kcp.UDPSession) *Conn {
|
||||
return &Conn{
|
||||
c := &Conn{
|
||||
remoteAddr: session.RemoteAddr(),
|
||||
ip: session.RemoteAddr().String(),
|
||||
kcp: session,
|
||||
write: func(data []byte) error {
|
||||
|
@ -17,11 +20,16 @@ func newKcpConn(session *kcp.UDPSession) *Conn {
|
|||
},
|
||||
data: map[any]any{},
|
||||
}
|
||||
if index := strings.LastIndex(c.ip, ":"); index != -1 {
|
||||
c.ip = c.ip[0:index]
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// newKcpConn 创建一个处理GNet的连接
|
||||
func newGNetConn(conn gnet.Conn) *Conn {
|
||||
return &Conn{
|
||||
c := &Conn{
|
||||
remoteAddr: conn.RemoteAddr(),
|
||||
ip: conn.RemoteAddr().String(),
|
||||
gn: conn,
|
||||
write: func(data []byte) error {
|
||||
|
@ -29,11 +37,16 @@ func newGNetConn(conn gnet.Conn) *Conn {
|
|||
},
|
||||
data: map[any]any{},
|
||||
}
|
||||
if index := strings.LastIndex(c.ip, ":"); index != -1 {
|
||||
c.ip = c.ip[0:index]
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// newKcpConn 创建一个处理WebSocket的连接
|
||||
func newWebsocketConn(ws *websocket.Conn, ip string) *Conn {
|
||||
return &Conn{
|
||||
remoteAddr: ws.RemoteAddr(),
|
||||
ip: ip,
|
||||
ws: ws,
|
||||
write: func(data []byte) error {
|
||||
|
@ -45,6 +58,7 @@ func newWebsocketConn(ws *websocket.Conn, ip string) *Conn {
|
|||
|
||||
// Conn 服务器连接
|
||||
type Conn struct {
|
||||
remoteAddr net.Addr
|
||||
ip string
|
||||
ws *websocket.Conn
|
||||
gn gnet.Conn
|
||||
|
@ -53,7 +67,15 @@ type Conn struct {
|
|||
data map[any]any
|
||||
}
|
||||
|
||||
func (slf *Conn) RemoteAddr() net.Addr {
|
||||
return slf.remoteAddr
|
||||
}
|
||||
|
||||
func (slf *Conn) GetID() string {
|
||||
return slf.remoteAddr.String()
|
||||
}
|
||||
|
||||
func (slf *Conn) GetIP() string {
|
||||
return slf.ip
|
||||
}
|
||||
|
||||
|
|
|
@ -16,4 +16,5 @@ var (
|
|||
ErrWebsocketIllegalMessageType = errors.New("illegal message type")
|
||||
ErrPleaseUseWebsocketHandle = errors.New("in Websocket mode, please use the RegConnectionReceiveWebsocketPacketEvent function to register")
|
||||
ErrPleaseUseOrdinaryPacketHandle = errors.New("non Websocket mode, please use the RegConnectionReceivePacketEvent function to register")
|
||||
ErrOnlySupportSocket = errors.New("only supports Socket programming")
|
||||
)
|
||||
|
|
|
@ -74,7 +74,11 @@ func (slf *event) RegConnectionOpenedEvent(handle ConnectionOpenedEventHandle) {
|
|||
}
|
||||
|
||||
func (slf *event) OnConnectionOpenedEvent(conn *Conn) {
|
||||
if len(slf.diversionMessageChannels) == 0 {
|
||||
log.Debug("Server", zap.String("ConnectionOpened", conn.GetID()))
|
||||
} else {
|
||||
log.Debug("Server", zap.String("ConnectionOpened", conn.GetID()), zap.Int("Node", slf.diversionConsistency.PickNode(conn.GetID())))
|
||||
}
|
||||
for _, handle := range slf.connectionOpenedEventHandles {
|
||||
handle(slf.Server, conn)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"github.com/kercylan98/minotaur/utils/hash"
|
||||
"github.com/kercylan98/minotaur/utils/log"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
|
@ -21,6 +22,25 @@ const (
|
|||
|
||||
type Option func(srv *Server)
|
||||
|
||||
// WithConnectPacketDiversion 通过连接数据包消息分流的方式创建服务器
|
||||
// - 连接消息分流后数据包消息将会从其他消息类型中独立出来,并且由多个消息管道及协程进行处理
|
||||
// - 默认不会进行消息分流
|
||||
// - 需要注意并发编程
|
||||
func WithConnectPacketDiversion(diversionNumber, channelSize int) Option {
|
||||
return func(srv *Server) {
|
||||
if srv.network == NetworkHttp || srv.network == NetworkGRPC {
|
||||
log.Warn("WithConnectPacketDiversion", zap.String("Network", string(srv.network)), zap.Error(ErrOnlySupportSocket))
|
||||
return
|
||||
}
|
||||
srv.diversionMessageChannels = make([]chan *message, diversionNumber)
|
||||
srv.diversionConsistency = hash.NewConsistency(3)
|
||||
for i := 0; i < diversionNumber; i++ {
|
||||
srv.diversionMessageChannels[i] = make(chan *message, channelSize)
|
||||
srv.diversionConsistency.AddNode(i + 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithTLS 通过安全传输层协议TLS创建服务器
|
||||
// - 支持:Http、Websocket
|
||||
func WithTLS(certFile, keyFile string) Option {
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/kercylan98/minotaur/utils/hash"
|
||||
"github.com/kercylan98/minotaur/utils/log"
|
||||
"github.com/kercylan98/minotaur/utils/synchronization"
|
||||
"github.com/panjf2000/gnet"
|
||||
|
@ -70,6 +71,8 @@ type Server struct {
|
|||
multiple bool // 是否为多服务器模式下运行
|
||||
prod bool // 是否为生产模式
|
||||
core int // 消息处理核心数
|
||||
diversionMessageChannels []chan *message // 分流消息管道
|
||||
diversionConsistency *hash.Consistency // 哈希一致性分流器
|
||||
}
|
||||
|
||||
// Run 使用特定地址运行服务器
|
||||
|
@ -118,6 +121,15 @@ func (slf *Server) Run(addr string) error {
|
|||
slf.dispatchMessage(message)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
for i := 0; i < len(slf.diversionMessageChannels); i++ {
|
||||
go func(channel chan *message) {
|
||||
for message := range channel {
|
||||
slf.dispatchMessage(message)
|
||||
}
|
||||
}(slf.diversionMessageChannels[i])
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -249,7 +261,7 @@ func (slf *Server) Run(addr string) error {
|
|||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if !slf.supportMessageTypes[messageType] {
|
||||
if len(slf.supportMessageTypes) > 0 && !slf.supportMessageTypes[messageType] {
|
||||
panic(ErrWebsocketIllegalMessageType)
|
||||
}
|
||||
slf.PushMessage(MessageTypePacket, conn, packet, messageType)
|
||||
|
@ -311,6 +323,11 @@ func (slf *Server) IsDev() bool {
|
|||
// Shutdown 停止运行服务器
|
||||
func (slf *Server) Shutdown(err error) {
|
||||
slf.isShutdown.Store(true)
|
||||
if len(slf.diversionMessageChannels) > 0 {
|
||||
for i := 0; i < len(slf.diversionMessageChannels); i++ {
|
||||
close(slf.diversionMessageChannels[i])
|
||||
}
|
||||
}
|
||||
if slf.initMessageChannel {
|
||||
if slf.gServer != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
|
@ -364,7 +381,12 @@ func (slf *Server) PushMessage(messageType MessageType, attrs ...any) {
|
|||
msg := slf.messagePool.Get()
|
||||
msg.t = messageType
|
||||
msg.attrs = attrs
|
||||
if messageType == MessageTypePacket && len(slf.diversionMessageChannels) > 0 {
|
||||
conn := attrs[0].(*Conn)
|
||||
slf.diversionMessageChannels[slf.diversionConsistency.PickNode(conn.ip)] <- msg
|
||||
} else {
|
||||
slf.messageChannel <- msg
|
||||
}
|
||||
}
|
||||
|
||||
// dispatchMessage 消息分发
|
||||
|
@ -381,6 +403,9 @@ func (slf *Server) dispatchMessage(msg *message) {
|
|||
case MessageTypePacket:
|
||||
if slf.network == NetworkWebsocket {
|
||||
conn, packet, messageType := msg.t.deconstructWebSocketPacket(msg.attrs...)
|
||||
if slf.diversionConsistency != nil {
|
||||
slf.diversionConsistency.PickNode(conn)
|
||||
}
|
||||
slf.OnConnectionReceiveWebsocketPacketEvent(conn, packet, messageType)
|
||||
} else {
|
||||
conn, packet := msg.t.deconstructPacket(msg.attrs...)
|
||||
|
|
|
@ -8,11 +8,17 @@ import (
|
|||
"strings"
|
||||
)
|
||||
|
||||
func NewConsistency(replicas int) *Consistency {
|
||||
return &Consistency{
|
||||
replicas: replicas,
|
||||
}
|
||||
}
|
||||
|
||||
// Consistency 一致性哈希生成
|
||||
//
|
||||
// https://blog.csdn.net/zhpCSDN921011/article/details/126845397
|
||||
type Consistency struct {
|
||||
Replicas int // 虚拟节点的数量
|
||||
replicas int // 虚拟节点的数量
|
||||
keys []int // 所有虚拟节点的哈希值
|
||||
hashMap map[int]int // 虚拟节点的哈希值: 节点(虚拟节点映射到真实节点)
|
||||
}
|
||||
|
@ -22,11 +28,11 @@ func (slf *Consistency) AddNode(keys ...int) {
|
|||
if slf.hashMap == nil {
|
||||
slf.hashMap = map[int]int{}
|
||||
}
|
||||
if slf.Replicas == 0 {
|
||||
slf.Replicas = 3
|
||||
if slf.replicas == 0 {
|
||||
slf.replicas = 3
|
||||
}
|
||||
for _, key := range keys {
|
||||
for i := 0; i < slf.Replicas; i++ {
|
||||
for i := 0; i < slf.replicas; i++ {
|
||||
// 计算虚拟节点哈希值
|
||||
hash := int(crc32.ChecksumIEEE([]byte(strconv.Itoa(i) + strconv.Itoa(key))))
|
||||
|
||||
|
|
Loading…
Reference in New Issue