修复Websocket消息类型过滤不设置时无法接收数据包的问题,服务器增加连接分流功能
This commit is contained in:
parent
df4aa30743
commit
926b69bee1
|
@ -0,0 +1,11 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import "github.com/kercylan98/minotaur/server"
|
||||||
|
|
||||||
|
// 无意义的测试main入口
|
||||||
|
func main() {
|
||||||
|
srv := server.New(server.NetworkWebsocket, server.WithConnectPacketDiversion(3, 2))
|
||||||
|
if err := srv.Run(":8999"); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
|
@ -4,38 +4,51 @@ import (
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"github.com/panjf2000/gnet"
|
"github.com/panjf2000/gnet"
|
||||||
"github.com/xtaci/kcp-go/v5"
|
"github.com/xtaci/kcp-go/v5"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// newKcpConn 创建一个处理KCP的连接
|
// newKcpConn 创建一个处理KCP的连接
|
||||||
func newKcpConn(session *kcp.UDPSession) *Conn {
|
func newKcpConn(session *kcp.UDPSession) *Conn {
|
||||||
return &Conn{
|
c := &Conn{
|
||||||
ip: session.RemoteAddr().String(),
|
remoteAddr: session.RemoteAddr(),
|
||||||
kcp: session,
|
ip: session.RemoteAddr().String(),
|
||||||
|
kcp: session,
|
||||||
write: func(data []byte) error {
|
write: func(data []byte) error {
|
||||||
_, err := session.Write(data)
|
_, err := session.Write(data)
|
||||||
return err
|
return err
|
||||||
},
|
},
|
||||||
data: map[any]any{},
|
data: map[any]any{},
|
||||||
}
|
}
|
||||||
|
if index := strings.LastIndex(c.ip, ":"); index != -1 {
|
||||||
|
c.ip = c.ip[0:index]
|
||||||
|
}
|
||||||
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// newKcpConn 创建一个处理GNet的连接
|
// newKcpConn 创建一个处理GNet的连接
|
||||||
func newGNetConn(conn gnet.Conn) *Conn {
|
func newGNetConn(conn gnet.Conn) *Conn {
|
||||||
return &Conn{
|
c := &Conn{
|
||||||
ip: conn.RemoteAddr().String(),
|
remoteAddr: conn.RemoteAddr(),
|
||||||
gn: conn,
|
ip: conn.RemoteAddr().String(),
|
||||||
|
gn: conn,
|
||||||
write: func(data []byte) error {
|
write: func(data []byte) error {
|
||||||
return conn.AsyncWrite(data)
|
return conn.AsyncWrite(data)
|
||||||
},
|
},
|
||||||
data: map[any]any{},
|
data: map[any]any{},
|
||||||
}
|
}
|
||||||
|
if index := strings.LastIndex(c.ip, ":"); index != -1 {
|
||||||
|
c.ip = c.ip[0:index]
|
||||||
|
}
|
||||||
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// newKcpConn 创建一个处理WebSocket的连接
|
// newKcpConn 创建一个处理WebSocket的连接
|
||||||
func newWebsocketConn(ws *websocket.Conn, ip string) *Conn {
|
func newWebsocketConn(ws *websocket.Conn, ip string) *Conn {
|
||||||
return &Conn{
|
return &Conn{
|
||||||
ip: ip,
|
remoteAddr: ws.RemoteAddr(),
|
||||||
ws: ws,
|
ip: ip,
|
||||||
|
ws: ws,
|
||||||
write: func(data []byte) error {
|
write: func(data []byte) error {
|
||||||
return ws.WriteMessage(websocket.BinaryMessage, data)
|
return ws.WriteMessage(websocket.BinaryMessage, data)
|
||||||
},
|
},
|
||||||
|
@ -45,15 +58,24 @@ func newWebsocketConn(ws *websocket.Conn, ip string) *Conn {
|
||||||
|
|
||||||
// Conn 服务器连接
|
// Conn 服务器连接
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
ip string
|
remoteAddr net.Addr
|
||||||
ws *websocket.Conn
|
ip string
|
||||||
gn gnet.Conn
|
ws *websocket.Conn
|
||||||
kcp *kcp.UDPSession
|
gn gnet.Conn
|
||||||
write func(data []byte) error
|
kcp *kcp.UDPSession
|
||||||
data map[any]any
|
write func(data []byte) error
|
||||||
|
data map[any]any
|
||||||
|
}
|
||||||
|
|
||||||
|
func (slf *Conn) RemoteAddr() net.Addr {
|
||||||
|
return slf.remoteAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (slf *Conn) GetID() string {
|
func (slf *Conn) GetID() string {
|
||||||
|
return slf.remoteAddr.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (slf *Conn) GetIP() string {
|
||||||
return slf.ip
|
return slf.ip
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,4 +16,5 @@ var (
|
||||||
ErrWebsocketIllegalMessageType = errors.New("illegal message type")
|
ErrWebsocketIllegalMessageType = errors.New("illegal message type")
|
||||||
ErrPleaseUseWebsocketHandle = errors.New("in Websocket mode, please use the RegConnectionReceiveWebsocketPacketEvent function to register")
|
ErrPleaseUseWebsocketHandle = errors.New("in Websocket mode, please use the RegConnectionReceiveWebsocketPacketEvent function to register")
|
||||||
ErrPleaseUseOrdinaryPacketHandle = errors.New("non Websocket mode, please use the RegConnectionReceivePacketEvent function to register")
|
ErrPleaseUseOrdinaryPacketHandle = errors.New("non Websocket mode, please use the RegConnectionReceivePacketEvent function to register")
|
||||||
|
ErrOnlySupportSocket = errors.New("only supports Socket programming")
|
||||||
)
|
)
|
||||||
|
|
|
@ -74,7 +74,11 @@ func (slf *event) RegConnectionOpenedEvent(handle ConnectionOpenedEventHandle) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (slf *event) OnConnectionOpenedEvent(conn *Conn) {
|
func (slf *event) OnConnectionOpenedEvent(conn *Conn) {
|
||||||
log.Debug("Server", zap.String("ConnectionOpened", conn.GetID()))
|
if len(slf.diversionMessageChannels) == 0 {
|
||||||
|
log.Debug("Server", zap.String("ConnectionOpened", conn.GetID()))
|
||||||
|
} else {
|
||||||
|
log.Debug("Server", zap.String("ConnectionOpened", conn.GetID()), zap.Int("Node", slf.diversionConsistency.PickNode(conn.GetID())))
|
||||||
|
}
|
||||||
for _, handle := range slf.connectionOpenedEventHandles {
|
for _, handle := range slf.connectionOpenedEventHandles {
|
||||||
handle(slf.Server, conn)
|
handle(slf.Server, conn)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/kercylan98/minotaur/utils/hash"
|
||||||
"github.com/kercylan98/minotaur/utils/log"
|
"github.com/kercylan98/minotaur/utils/log"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
@ -21,6 +22,25 @@ const (
|
||||||
|
|
||||||
type Option func(srv *Server)
|
type Option func(srv *Server)
|
||||||
|
|
||||||
|
// WithConnectPacketDiversion 通过连接数据包消息分流的方式创建服务器
|
||||||
|
// - 连接消息分流后数据包消息将会从其他消息类型中独立出来,并且由多个消息管道及协程进行处理
|
||||||
|
// - 默认不会进行消息分流
|
||||||
|
// - 需要注意并发编程
|
||||||
|
func WithConnectPacketDiversion(diversionNumber, channelSize int) Option {
|
||||||
|
return func(srv *Server) {
|
||||||
|
if srv.network == NetworkHttp || srv.network == NetworkGRPC {
|
||||||
|
log.Warn("WithConnectPacketDiversion", zap.String("Network", string(srv.network)), zap.Error(ErrOnlySupportSocket))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
srv.diversionMessageChannels = make([]chan *message, diversionNumber)
|
||||||
|
srv.diversionConsistency = hash.NewConsistency(3)
|
||||||
|
for i := 0; i < diversionNumber; i++ {
|
||||||
|
srv.diversionMessageChannels[i] = make(chan *message, channelSize)
|
||||||
|
srv.diversionConsistency.AddNode(i + 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithTLS 通过安全传输层协议TLS创建服务器
|
// WithTLS 通过安全传输层协议TLS创建服务器
|
||||||
// - 支持:Http、Websocket
|
// - 支持:Http、Websocket
|
||||||
func WithTLS(certFile, keyFile string) Option {
|
func WithTLS(certFile, keyFile string) Option {
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/kercylan98/minotaur/utils/hash"
|
||||||
"github.com/kercylan98/minotaur/utils/log"
|
"github.com/kercylan98/minotaur/utils/log"
|
||||||
"github.com/kercylan98/minotaur/utils/synchronization"
|
"github.com/kercylan98/minotaur/utils/synchronization"
|
||||||
"github.com/panjf2000/gnet"
|
"github.com/panjf2000/gnet"
|
||||||
|
@ -62,14 +63,16 @@ type Server struct {
|
||||||
isShutdown atomic.Bool // 是否已关闭
|
isShutdown atomic.Bool // 是否已关闭
|
||||||
closeChannel chan struct{} // 关闭信号
|
closeChannel chan struct{} // 关闭信号
|
||||||
|
|
||||||
gServer *gNet // TCP或UDP模式下的服务器
|
gServer *gNet // TCP或UDP模式下的服务器
|
||||||
messagePool *synchronization.Pool[*message] // 消息池
|
messagePool *synchronization.Pool[*message] // 消息池
|
||||||
messagePoolSize int // 消息池大小
|
messagePoolSize int // 消息池大小
|
||||||
messageChannel chan *message // 消息管道
|
messageChannel chan *message // 消息管道
|
||||||
initMessageChannel bool // 消息管道是否已经初始化
|
initMessageChannel bool // 消息管道是否已经初始化
|
||||||
multiple bool // 是否为多服务器模式下运行
|
multiple bool // 是否为多服务器模式下运行
|
||||||
prod bool // 是否为生产模式
|
prod bool // 是否为生产模式
|
||||||
core int // 消息处理核心数
|
core int // 消息处理核心数
|
||||||
|
diversionMessageChannels []chan *message // 分流消息管道
|
||||||
|
diversionConsistency *hash.Consistency // 哈希一致性分流器
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run 使用特定地址运行服务器
|
// Run 使用特定地址运行服务器
|
||||||
|
@ -118,6 +121,15 @@ func (slf *Server) Run(addr string) error {
|
||||||
slf.dispatchMessage(message)
|
slf.dispatchMessage(message)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
go func() {
|
||||||
|
for i := 0; i < len(slf.diversionMessageChannels); i++ {
|
||||||
|
go func(channel chan *message) {
|
||||||
|
for message := range channel {
|
||||||
|
slf.dispatchMessage(message)
|
||||||
|
}
|
||||||
|
}(slf.diversionMessageChannels[i])
|
||||||
|
}
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -249,7 +261,7 @@ func (slf *Server) Run(addr string) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
if !slf.supportMessageTypes[messageType] {
|
if len(slf.supportMessageTypes) > 0 && !slf.supportMessageTypes[messageType] {
|
||||||
panic(ErrWebsocketIllegalMessageType)
|
panic(ErrWebsocketIllegalMessageType)
|
||||||
}
|
}
|
||||||
slf.PushMessage(MessageTypePacket, conn, packet, messageType)
|
slf.PushMessage(MessageTypePacket, conn, packet, messageType)
|
||||||
|
@ -311,6 +323,11 @@ func (slf *Server) IsDev() bool {
|
||||||
// Shutdown 停止运行服务器
|
// Shutdown 停止运行服务器
|
||||||
func (slf *Server) Shutdown(err error) {
|
func (slf *Server) Shutdown(err error) {
|
||||||
slf.isShutdown.Store(true)
|
slf.isShutdown.Store(true)
|
||||||
|
if len(slf.diversionMessageChannels) > 0 {
|
||||||
|
for i := 0; i < len(slf.diversionMessageChannels); i++ {
|
||||||
|
close(slf.diversionMessageChannels[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
if slf.initMessageChannel {
|
if slf.initMessageChannel {
|
||||||
if slf.gServer != nil {
|
if slf.gServer != nil {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
@ -364,7 +381,12 @@ func (slf *Server) PushMessage(messageType MessageType, attrs ...any) {
|
||||||
msg := slf.messagePool.Get()
|
msg := slf.messagePool.Get()
|
||||||
msg.t = messageType
|
msg.t = messageType
|
||||||
msg.attrs = attrs
|
msg.attrs = attrs
|
||||||
slf.messageChannel <- msg
|
if messageType == MessageTypePacket && len(slf.diversionMessageChannels) > 0 {
|
||||||
|
conn := attrs[0].(*Conn)
|
||||||
|
slf.diversionMessageChannels[slf.diversionConsistency.PickNode(conn.ip)] <- msg
|
||||||
|
} else {
|
||||||
|
slf.messageChannel <- msg
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// dispatchMessage 消息分发
|
// dispatchMessage 消息分发
|
||||||
|
@ -381,6 +403,9 @@ func (slf *Server) dispatchMessage(msg *message) {
|
||||||
case MessageTypePacket:
|
case MessageTypePacket:
|
||||||
if slf.network == NetworkWebsocket {
|
if slf.network == NetworkWebsocket {
|
||||||
conn, packet, messageType := msg.t.deconstructWebSocketPacket(msg.attrs...)
|
conn, packet, messageType := msg.t.deconstructWebSocketPacket(msg.attrs...)
|
||||||
|
if slf.diversionConsistency != nil {
|
||||||
|
slf.diversionConsistency.PickNode(conn)
|
||||||
|
}
|
||||||
slf.OnConnectionReceiveWebsocketPacketEvent(conn, packet, messageType)
|
slf.OnConnectionReceiveWebsocketPacketEvent(conn, packet, messageType)
|
||||||
} else {
|
} else {
|
||||||
conn, packet := msg.t.deconstructPacket(msg.attrs...)
|
conn, packet := msg.t.deconstructPacket(msg.attrs...)
|
||||||
|
|
|
@ -8,11 +8,17 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func NewConsistency(replicas int) *Consistency {
|
||||||
|
return &Consistency{
|
||||||
|
replicas: replicas,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Consistency 一致性哈希生成
|
// Consistency 一致性哈希生成
|
||||||
//
|
//
|
||||||
// https://blog.csdn.net/zhpCSDN921011/article/details/126845397
|
// https://blog.csdn.net/zhpCSDN921011/article/details/126845397
|
||||||
type Consistency struct {
|
type Consistency struct {
|
||||||
Replicas int // 虚拟节点的数量
|
replicas int // 虚拟节点的数量
|
||||||
keys []int // 所有虚拟节点的哈希值
|
keys []int // 所有虚拟节点的哈希值
|
||||||
hashMap map[int]int // 虚拟节点的哈希值: 节点(虚拟节点映射到真实节点)
|
hashMap map[int]int // 虚拟节点的哈希值: 节点(虚拟节点映射到真实节点)
|
||||||
}
|
}
|
||||||
|
@ -22,11 +28,11 @@ func (slf *Consistency) AddNode(keys ...int) {
|
||||||
if slf.hashMap == nil {
|
if slf.hashMap == nil {
|
||||||
slf.hashMap = map[int]int{}
|
slf.hashMap = map[int]int{}
|
||||||
}
|
}
|
||||||
if slf.Replicas == 0 {
|
if slf.replicas == 0 {
|
||||||
slf.Replicas = 3
|
slf.replicas = 3
|
||||||
}
|
}
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
for i := 0; i < slf.Replicas; i++ {
|
for i := 0; i < slf.replicas; i++ {
|
||||||
// 计算虚拟节点哈希值
|
// 计算虚拟节点哈希值
|
||||||
hash := int(crc32.ChecksumIEEE([]byte(strconv.Itoa(i) + strconv.Itoa(key))))
|
hash := int(crc32.ChecksumIEEE([]byte(strconv.Itoa(i) + strconv.Itoa(key))))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue