diff --git a/go.mod b/go.mod index 04a3d53..f44d03a 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,7 @@ require ( github.com/xtaci/kcp-go/v5 v5.6.3 go.uber.org/atomic v1.10.0 go.uber.org/zap v1.25.0 + golang.org/x/time v0.3.0 google.golang.org/grpc v1.57.0 ) @@ -50,9 +51,7 @@ require ( github.com/nats-io/nkeys v0.4.4 // indirect github.com/nats-io/nuid v1.0.1 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect - github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5 // indirect github.com/pkg/errors v0.9.1 // indirect - github.com/sasha-s/go-deadlock v0.3.1 // indirect github.com/smarty/assertions v1.15.0 // indirect github.com/templexxx/cpu v0.1.0 // indirect github.com/templexxx/xorsimd v0.4.2 // indirect diff --git a/go.sum b/go.sum index 143cb43..bb33f3b 100644 --- a/go.sum +++ b/go.sum @@ -134,8 +134,6 @@ github.com/panjf2000/gnet v1.6.7/go.mod h1:KcOU7QsCaCBjeD5kyshBIamG3d9kAQtlob4Y0 github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo= github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= -github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5 h1:q2e307iGHPdTGp0hoxKjt1H5pDo6utceo3dQVK3I5XQ= -github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5/go.mod h1:jvVRKCrJTQWu0XVbaOlby/2lO20uSCHEMzzplHXte1o= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= @@ -147,8 +145,6 @@ github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTE github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= -github.com/sasha-s/go-deadlock v0.3.1 h1:sqv7fDNShgjcaxkO0JNcOAlr8B9+cV5Ey/OB71efZx0= -github.com/sasha-s/go-deadlock v0.3.1/go.mod h1:F73l+cr82YSh10GxyRI6qZiCgK64VaZjwesgfQ1/iLM= github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY= github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec= github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY= diff --git a/server/event.go b/server/event.go index d5f0dea..cf38d80 100644 --- a/server/event.go +++ b/server/event.go @@ -26,6 +26,7 @@ type ConnectionWritePacketBeforeEventHandle func(srv *Server, conn *Conn, packet type ShuntChannelCreatedEventHandle func(srv *Server, guid int64) type ShuntChannelClosedEventHandle func(srv *Server, guid int64) type ConnectionPacketPreprocessEventHandle func(srv *Server, conn *Conn, packet []byte, abort func(), usePacket func(newPacket []byte)) +type MessageExecBeforeEventHandle func(srv *Server, message *Message) bool func newEvent(srv *Server) *event { return &event{ @@ -44,6 +45,7 @@ func newEvent(srv *Server) *event { shuntChannelCreatedEventHandles: slice.NewPriority[ShuntChannelCreatedEventHandle](), shuntChannelClosedEventHandles: slice.NewPriority[ShuntChannelClosedEventHandle](), connectionPacketPreprocessEventHandles: slice.NewPriority[ConnectionPacketPreprocessEventHandle](), + messageExecBeforeEventHandles: slice.NewPriority[MessageExecBeforeEventHandle](), } } @@ -63,6 +65,7 @@ type event struct { shuntChannelCreatedEventHandles *slice.Priority[ShuntChannelCreatedEventHandle] shuntChannelClosedEventHandles *slice.Priority[ShuntChannelClosedEventHandle] connectionPacketPreprocessEventHandles *slice.Priority[ConnectionPacketPreprocessEventHandle] + messageExecBeforeEventHandles *slice.Priority[MessageExecBeforeEventHandle] consoleCommandEventHandles map[string]*slice.Priority[ConsoleCommandEventHandle] consoleCommandEventHandleInitOnce sync.Once @@ -353,6 +356,34 @@ func (slf *event) OnConnectionPacketPreprocessEvent(conn *Conn, packet []byte, u return abort } +// RegMessageExecBeforeEvent 在处理消息前将立刻执行被注册的事件处理函数 +// - 当返回 true 时,将继续执行后续的消息处理函数,否则将不会执行后续的消息处理函数,并且该消息将被丢弃 +// +// 适用于限流等场景 +func (slf *event) RegMessageExecBeforeEvent(handle MessageExecBeforeEventHandle, priority ...int) { + slf.messageExecBeforeEventHandles.Append(handle, slice.GetValue(priority, 0)) + log.Info("Server", log.String("RegEvent", runtimes.CurrentRunningFuncName()), log.String("handle", reflect.TypeOf(handle).String())) +} + +// OnMessageExecBeforeEvent 执行消息处理前的事件处理函数 +func (slf *event) OnMessageExecBeforeEvent(message *Message) bool { + if slf.messageExecBeforeEventHandles.Len() == 0 { + return true + } + var result = true + defer func() { + if err := recover(); err != nil { + log.Error("Server", log.String("OnMessageExecBeforeEvent", fmt.Sprintf("%v", err))) + debug.PrintStack() + } + }() + slf.messageExecBeforeEventHandles.RangeValue(func(index int, value MessageExecBeforeEventHandle) bool { + result = value(slf.Server, message) + return result + }) + return result +} + func (slf *event) check() { switch slf.network { case NetworkHttp, NetworkGRPC, NetworkNone: diff --git a/server/message.go b/server/message.go index 64d8ed0..7f027e9 100644 --- a/server/message.go +++ b/server/message.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "github.com/kercylan98/minotaur/utils/hash" "reflect" ) @@ -64,6 +65,11 @@ type ( MessageErrorAction byte ) +// HasMessageType 检查是否存在指定的消息类型 +func HasMessageType(mt MessageType) bool { + return hash.Exist(messageNames, mt) +} + func (slf MessageErrorAction) String() string { return messageErrorActionNames[slf] } @@ -74,6 +80,11 @@ type Message struct { attrs []any // 消息属性 } +// MessageType 返回消息类型 +func (slf *Message) MessageType() MessageType { + return slf.t +} + // String 返回消息的字符串表示 func (slf *Message) String() string { var attrs = make([]any, 0, len(slf.attrs)) @@ -106,6 +117,13 @@ func (slf MessageType) String() string { return messageNames[slf] } +// GetPacketMessageAttrs 获取消息中的数据包属性 +func (slf *Message) GetPacketMessageAttrs() (conn *Conn, packet []byte) { + conn = slf.attrs[0].(*Conn) + packet = slf.attrs[1].([]byte) + return +} + // PushPacketMessage 向特定服务器中推送 MessageTypePacket 消息 func PushPacketMessage(srv *Server, conn *Conn, wst int, packet []byte, mark ...any) { msg := srv.messagePool.Get() @@ -114,6 +132,13 @@ func PushPacketMessage(srv *Server, conn *Conn, wst int, packet []byte, mark ... srv.pushMessage(msg) } +// GetErrorMessageAttrs 获取消息中的错误属性 +func (slf *Message) GetErrorMessageAttrs() (err error, action MessageErrorAction) { + err = slf.attrs[0].(error) + action = slf.attrs[1].(MessageErrorAction) + return +} + // PushErrorMessage 向特定服务器中推送 MessageTypeError 消息 func PushErrorMessage(srv *Server, err error, action MessageErrorAction, mark ...any) { msg := srv.messagePool.Get() @@ -122,6 +147,13 @@ func PushErrorMessage(srv *Server, err error, action MessageErrorAction, mark .. srv.pushMessage(msg) } +// GetCrossMessageAttrs 获取消息中的跨服属性 +func (slf *Message) GetCrossMessageAttrs() (serverId int64, packet []byte) { + serverId = slf.attrs[0].(int64) + packet = slf.attrs[1].([]byte) + return +} + // PushCrossMessage 向特定服务器中推送 MessageTypeCross 消息 func PushCrossMessage(srv *Server, crossName string, serverId int64, packet []byte, mark ...any) { if serverId == srv.id { @@ -141,6 +173,12 @@ func PushCrossMessage(srv *Server, crossName string, serverId int64, packet []by } } +// GetTickerMessageAttrs 获取消息中的定时器属性 +func (slf *Message) GetTickerMessageAttrs() (caller func()) { + caller = slf.attrs[0].(func()) + return +} + // PushTickerMessage 向特定服务器中推送 MessageTypeTicker 消息 func PushTickerMessage(srv *Server, caller func(), mark ...any) { msg := srv.messagePool.Get() @@ -149,6 +187,13 @@ func PushTickerMessage(srv *Server, caller func(), mark ...any) { srv.pushMessage(msg) } +// GetAsyncMessageAttrs 获取消息中的异步消息属性 +func (slf *Message) GetAsyncMessageAttrs() (caller func() error, callback func(err error), hasCallback bool) { + caller = slf.attrs[0].(func() error) + callback, hasCallback = slf.attrs[1].(func(err error)) + return +} + // PushAsyncMessage 向特定服务器中推送 MessageTypeAsync 消息 // - 异步消息将在服务器的异步消息队列中进行处理,处理完成 caller 的阻塞操作后,将会通过系统消息执行 callback 函数 // - callback 函数将在异步消息处理完成后进行调用,无论过程是否产生 err,都将被执行,允许为 nil @@ -162,6 +207,12 @@ func PushAsyncMessage(srv *Server, caller func() error, callback func(err error) srv.pushMessage(msg) } +// GetSystemMessageAttrs 获取消息中的系统消息属性 +func (slf *Message) GetSystemMessageAttrs() (handle func()) { + handle = slf.attrs[0].(func()) + return +} + // PushSystemMessage 向特定服务器中推送 MessageTypeSystem 消息 func PushSystemMessage(srv *Server, handle func(), mark ...any) { msg := srv.messagePool.Get() diff --git a/server/server.go b/server/server.go index 79263b9..0405a72 100644 --- a/server/server.go +++ b/server/server.go @@ -102,18 +102,17 @@ type Server struct { } // Run 使用特定地址运行服务器 -// -// server.NetworkTcp (addr:":8888") -// server.NetworkTcp4 (addr:":8888") -// server.NetworkTcp6 (addr:":8888") -// server.NetworkUdp (addr:":8888") -// server.NetworkUdp4 (addr:":8888") -// server.NetworkUdp6 (addr:":8888") -// server.NetworkUnix (addr:"socketPath") -// server.NetworkHttp (addr:":8888") -// server.NetworkWebsocket (addr:":8888/ws") -// server.NetworkKcp (addr:":8888") -// server.NetworkNone (addr:"") +// - server.NetworkTcp (addr:":8888") +// - server.NetworkTcp4 (addr:":8888") +// - server.NetworkTcp6 (addr:":8888") +// - server.NetworkUdp (addr:":8888") +// - server.NetworkUdp4 (addr:":8888") +// - server.NetworkUdp6 (addr:":8888") +// - server.NetworkUnix (addr:"socketPath") +// - server.NetworkHttp (addr:":8888") +// - server.NetworkWebsocket (addr:":8888/ws") +// - server.NetworkKcp (addr:":8888") +// - server.NetworkNone (addr:"") func (slf *Server) Run(addr string) error { if slf.network == NetworkNone { addr = "-" @@ -365,6 +364,16 @@ func (slf *Server) RunNone() error { return slf.Run(str.None) } +// Context 获取服务器上下文 +func (slf *Server) Context() context.Context { + return slf.ctx +} + +// TimeoutContext 获取服务器超时上下文,context.WithTimeout 的简写 +func (slf *Server) TimeoutContext(timeout time.Duration) (context.Context, context.CancelFunc) { + return context.WithTimeout(slf.ctx, timeout) +} + // GetOnlineCount 获取在线人数 func (slf *Server) GetOnlineCount() int { return slf.online.Size() @@ -546,7 +555,7 @@ func (slf *Server) pushMessage(message *Message) { slf.messagePool.Release(message) return } - if slf.isShutdown.Load() { + if slf.isShutdown.Load() || !slf.OnMessageExecBeforeEvent(message) { return } if slf.shuntChannels != nil && message.t == MessageTypePacket { @@ -632,13 +641,12 @@ func (slf *Server) dispatchMessage(msg *Message) { var attrs = msg.attrs switch msg.t { case MessageTypePacket: - var conn = attrs[0].(*Conn) - var packet = attrs[1].([]byte) + var conn, packet = msg.GetPacketMessageAttrs() if !slf.OnConnectionPacketPreprocessEvent(conn, packet, func(newPacket []byte) { packet = newPacket }) { slf.OnConnectionReceivePacketEvent(conn, packet) } case MessageTypeError: - err, action := attrs[0].(error), attrs[1].(MessageErrorAction) + var err, action = msg.GetErrorMessageAttrs() switch action { case MessageErrorActionNone: log.Panic("Server", log.Err(err)) @@ -648,12 +656,11 @@ func (slf *Server) dispatchMessage(msg *Message) { log.Warn("Server", log.String("not support message error action", action.String())) } case MessageTypeCross: - slf.OnReceiveCrossPacketEvent(attrs[0].(int64), attrs[1].([]byte)) + slf.OnReceiveCrossPacketEvent(msg.GetCrossMessageAttrs()) case MessageTypeTicker: - attrs[0].(func())() + msg.GetTickerMessageAttrs()() case MessageTypeAsync: - handle := attrs[0].(func() error) - callback, cb := attrs[1].(func(err error)) + handle, callback, cb := msg.GetAsyncMessageAttrs() if err := slf.ants.Submit(func() { defer func() { if err := recover(); err != nil { @@ -688,10 +695,10 @@ func (slf *Server) dispatchMessage(msg *Message) { }); err != nil { panic(err) } - case MessageTypeAsyncCallback: + case MessageTypeAsyncCallback: // 特殊类型 attrs[0].(func())() case MessageTypeSystem: - attrs[0].(func())() + msg.GetSystemMessageAttrs()() default: log.Warn("Server", log.String("not support message type", msg.t.String())) } diff --git a/server/server_test.go b/server/server_test.go index 1d76b07..3b056e0 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -5,13 +5,23 @@ import ( "github.com/kercylan98/minotaur/server" "github.com/kercylan98/minotaur/server/client" "github.com/kercylan98/minotaur/utils/times" + "golang.org/x/time/rate" "sync/atomic" "testing" "time" ) func TestNew(t *testing.T) { + limiter := rate.NewLimiter(rate.Every(time.Second), 100) srv := server.New(server.NetworkWebsocket, server.WithMessageBufferSize(1024*1024), server.WithPProf()) + srv.RegMessageExecBeforeEvent(func(srv *server.Server, message *server.Message) bool { + t, c := srv.TimeoutContext(time.Second * 5) + defer c() + if err := limiter.Wait(t); err != nil { + return false + } + return true + }) srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet []byte) { conn.Write(packet) })