feat: server 包新增 Server.RegMessageExecBeforeEvent 函数,支持在消息执行前进行处理,适用于限流等场景

This commit is contained in:
kercylan98 2023-09-05 10:52:09 +08:00
parent add1e4bc8c
commit 0297c4444a
6 changed files with 122 additions and 28 deletions

3
go.mod
View File

@ -21,6 +21,7 @@ require (
github.com/xtaci/kcp-go/v5 v5.6.3 github.com/xtaci/kcp-go/v5 v5.6.3
go.uber.org/atomic v1.10.0 go.uber.org/atomic v1.10.0
go.uber.org/zap v1.25.0 go.uber.org/zap v1.25.0
golang.org/x/time v0.3.0
google.golang.org/grpc v1.57.0 google.golang.org/grpc v1.57.0
) )
@ -50,9 +51,7 @@ require (
github.com/nats-io/nkeys v0.4.4 // indirect github.com/nats-io/nkeys v0.4.4 // indirect
github.com/nats-io/nuid v1.0.1 // indirect github.com/nats-io/nuid v1.0.1 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5 // indirect
github.com/pkg/errors v0.9.1 // indirect github.com/pkg/errors v0.9.1 // indirect
github.com/sasha-s/go-deadlock v0.3.1 // indirect
github.com/smarty/assertions v1.15.0 // indirect github.com/smarty/assertions v1.15.0 // indirect
github.com/templexxx/cpu v0.1.0 // indirect github.com/templexxx/cpu v0.1.0 // indirect
github.com/templexxx/xorsimd v0.4.2 // indirect github.com/templexxx/xorsimd v0.4.2 // indirect

4
go.sum
View File

@ -134,8 +134,6 @@ github.com/panjf2000/gnet v1.6.7/go.mod h1:KcOU7QsCaCBjeD5kyshBIamG3d9kAQtlob4Y0
github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo= github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5 h1:q2e307iGHPdTGp0hoxKjt1H5pDo6utceo3dQVK3I5XQ=
github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5/go.mod h1:jvVRKCrJTQWu0XVbaOlby/2lO20uSCHEMzzplHXte1o=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
@ -147,8 +145,6 @@ github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTE
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/sasha-s/go-deadlock v0.3.1 h1:sqv7fDNShgjcaxkO0JNcOAlr8B9+cV5Ey/OB71efZx0=
github.com/sasha-s/go-deadlock v0.3.1/go.mod h1:F73l+cr82YSh10GxyRI6qZiCgK64VaZjwesgfQ1/iLM=
github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY= github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY=
github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec= github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec=
github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY= github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY=

View File

@ -26,6 +26,7 @@ type ConnectionWritePacketBeforeEventHandle func(srv *Server, conn *Conn, packet
type ShuntChannelCreatedEventHandle func(srv *Server, guid int64) type ShuntChannelCreatedEventHandle func(srv *Server, guid int64)
type ShuntChannelClosedEventHandle func(srv *Server, guid int64) type ShuntChannelClosedEventHandle func(srv *Server, guid int64)
type ConnectionPacketPreprocessEventHandle func(srv *Server, conn *Conn, packet []byte, abort func(), usePacket func(newPacket []byte)) type ConnectionPacketPreprocessEventHandle func(srv *Server, conn *Conn, packet []byte, abort func(), usePacket func(newPacket []byte))
type MessageExecBeforeEventHandle func(srv *Server, message *Message) bool
func newEvent(srv *Server) *event { func newEvent(srv *Server) *event {
return &event{ return &event{
@ -44,6 +45,7 @@ func newEvent(srv *Server) *event {
shuntChannelCreatedEventHandles: slice.NewPriority[ShuntChannelCreatedEventHandle](), shuntChannelCreatedEventHandles: slice.NewPriority[ShuntChannelCreatedEventHandle](),
shuntChannelClosedEventHandles: slice.NewPriority[ShuntChannelClosedEventHandle](), shuntChannelClosedEventHandles: slice.NewPriority[ShuntChannelClosedEventHandle](),
connectionPacketPreprocessEventHandles: slice.NewPriority[ConnectionPacketPreprocessEventHandle](), connectionPacketPreprocessEventHandles: slice.NewPriority[ConnectionPacketPreprocessEventHandle](),
messageExecBeforeEventHandles: slice.NewPriority[MessageExecBeforeEventHandle](),
} }
} }
@ -63,6 +65,7 @@ type event struct {
shuntChannelCreatedEventHandles *slice.Priority[ShuntChannelCreatedEventHandle] shuntChannelCreatedEventHandles *slice.Priority[ShuntChannelCreatedEventHandle]
shuntChannelClosedEventHandles *slice.Priority[ShuntChannelClosedEventHandle] shuntChannelClosedEventHandles *slice.Priority[ShuntChannelClosedEventHandle]
connectionPacketPreprocessEventHandles *slice.Priority[ConnectionPacketPreprocessEventHandle] connectionPacketPreprocessEventHandles *slice.Priority[ConnectionPacketPreprocessEventHandle]
messageExecBeforeEventHandles *slice.Priority[MessageExecBeforeEventHandle]
consoleCommandEventHandles map[string]*slice.Priority[ConsoleCommandEventHandle] consoleCommandEventHandles map[string]*slice.Priority[ConsoleCommandEventHandle]
consoleCommandEventHandleInitOnce sync.Once consoleCommandEventHandleInitOnce sync.Once
@ -353,6 +356,34 @@ func (slf *event) OnConnectionPacketPreprocessEvent(conn *Conn, packet []byte, u
return abort return abort
} }
// RegMessageExecBeforeEvent 在处理消息前将立刻执行被注册的事件处理函数
// - 当返回 true 时,将继续执行后续的消息处理函数,否则将不会执行后续的消息处理函数,并且该消息将被丢弃
//
// 适用于限流等场景
func (slf *event) RegMessageExecBeforeEvent(handle MessageExecBeforeEventHandle, priority ...int) {
slf.messageExecBeforeEventHandles.Append(handle, slice.GetValue(priority, 0))
log.Info("Server", log.String("RegEvent", runtimes.CurrentRunningFuncName()), log.String("handle", reflect.TypeOf(handle).String()))
}
// OnMessageExecBeforeEvent 执行消息处理前的事件处理函数
func (slf *event) OnMessageExecBeforeEvent(message *Message) bool {
if slf.messageExecBeforeEventHandles.Len() == 0 {
return true
}
var result = true
defer func() {
if err := recover(); err != nil {
log.Error("Server", log.String("OnMessageExecBeforeEvent", fmt.Sprintf("%v", err)))
debug.PrintStack()
}
}()
slf.messageExecBeforeEventHandles.RangeValue(func(index int, value MessageExecBeforeEventHandle) bool {
result = value(slf.Server, message)
return result
})
return result
}
func (slf *event) check() { func (slf *event) check() {
switch slf.network { switch slf.network {
case NetworkHttp, NetworkGRPC, NetworkNone: case NetworkHttp, NetworkGRPC, NetworkNone:

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/kercylan98/minotaur/utils/hash"
"reflect" "reflect"
) )
@ -64,6 +65,11 @@ type (
MessageErrorAction byte MessageErrorAction byte
) )
// HasMessageType 检查是否存在指定的消息类型
func HasMessageType(mt MessageType) bool {
return hash.Exist(messageNames, mt)
}
func (slf MessageErrorAction) String() string { func (slf MessageErrorAction) String() string {
return messageErrorActionNames[slf] return messageErrorActionNames[slf]
} }
@ -74,6 +80,11 @@ type Message struct {
attrs []any // 消息属性 attrs []any // 消息属性
} }
// MessageType 返回消息类型
func (slf *Message) MessageType() MessageType {
return slf.t
}
// String 返回消息的字符串表示 // String 返回消息的字符串表示
func (slf *Message) String() string { func (slf *Message) String() string {
var attrs = make([]any, 0, len(slf.attrs)) var attrs = make([]any, 0, len(slf.attrs))
@ -106,6 +117,13 @@ func (slf MessageType) String() string {
return messageNames[slf] return messageNames[slf]
} }
// GetPacketMessageAttrs 获取消息中的数据包属性
func (slf *Message) GetPacketMessageAttrs() (conn *Conn, packet []byte) {
conn = slf.attrs[0].(*Conn)
packet = slf.attrs[1].([]byte)
return
}
// PushPacketMessage 向特定服务器中推送 MessageTypePacket 消息 // PushPacketMessage 向特定服务器中推送 MessageTypePacket 消息
func PushPacketMessage(srv *Server, conn *Conn, wst int, packet []byte, mark ...any) { func PushPacketMessage(srv *Server, conn *Conn, wst int, packet []byte, mark ...any) {
msg := srv.messagePool.Get() msg := srv.messagePool.Get()
@ -114,6 +132,13 @@ func PushPacketMessage(srv *Server, conn *Conn, wst int, packet []byte, mark ...
srv.pushMessage(msg) srv.pushMessage(msg)
} }
// GetErrorMessageAttrs 获取消息中的错误属性
func (slf *Message) GetErrorMessageAttrs() (err error, action MessageErrorAction) {
err = slf.attrs[0].(error)
action = slf.attrs[1].(MessageErrorAction)
return
}
// PushErrorMessage 向特定服务器中推送 MessageTypeError 消息 // PushErrorMessage 向特定服务器中推送 MessageTypeError 消息
func PushErrorMessage(srv *Server, err error, action MessageErrorAction, mark ...any) { func PushErrorMessage(srv *Server, err error, action MessageErrorAction, mark ...any) {
msg := srv.messagePool.Get() msg := srv.messagePool.Get()
@ -122,6 +147,13 @@ func PushErrorMessage(srv *Server, err error, action MessageErrorAction, mark ..
srv.pushMessage(msg) srv.pushMessage(msg)
} }
// GetCrossMessageAttrs 获取消息中的跨服属性
func (slf *Message) GetCrossMessageAttrs() (serverId int64, packet []byte) {
serverId = slf.attrs[0].(int64)
packet = slf.attrs[1].([]byte)
return
}
// PushCrossMessage 向特定服务器中推送 MessageTypeCross 消息 // PushCrossMessage 向特定服务器中推送 MessageTypeCross 消息
func PushCrossMessage(srv *Server, crossName string, serverId int64, packet []byte, mark ...any) { func PushCrossMessage(srv *Server, crossName string, serverId int64, packet []byte, mark ...any) {
if serverId == srv.id { if serverId == srv.id {
@ -141,6 +173,12 @@ func PushCrossMessage(srv *Server, crossName string, serverId int64, packet []by
} }
} }
// GetTickerMessageAttrs 获取消息中的定时器属性
func (slf *Message) GetTickerMessageAttrs() (caller func()) {
caller = slf.attrs[0].(func())
return
}
// PushTickerMessage 向特定服务器中推送 MessageTypeTicker 消息 // PushTickerMessage 向特定服务器中推送 MessageTypeTicker 消息
func PushTickerMessage(srv *Server, caller func(), mark ...any) { func PushTickerMessage(srv *Server, caller func(), mark ...any) {
msg := srv.messagePool.Get() msg := srv.messagePool.Get()
@ -149,6 +187,13 @@ func PushTickerMessage(srv *Server, caller func(), mark ...any) {
srv.pushMessage(msg) srv.pushMessage(msg)
} }
// GetAsyncMessageAttrs 获取消息中的异步消息属性
func (slf *Message) GetAsyncMessageAttrs() (caller func() error, callback func(err error), hasCallback bool) {
caller = slf.attrs[0].(func() error)
callback, hasCallback = slf.attrs[1].(func(err error))
return
}
// PushAsyncMessage 向特定服务器中推送 MessageTypeAsync 消息 // PushAsyncMessage 向特定服务器中推送 MessageTypeAsync 消息
// - 异步消息将在服务器的异步消息队列中进行处理,处理完成 caller 的阻塞操作后,将会通过系统消息执行 callback 函数 // - 异步消息将在服务器的异步消息队列中进行处理,处理完成 caller 的阻塞操作后,将会通过系统消息执行 callback 函数
// - callback 函数将在异步消息处理完成后进行调用,无论过程是否产生 err都将被执行允许为 nil // - callback 函数将在异步消息处理完成后进行调用,无论过程是否产生 err都将被执行允许为 nil
@ -162,6 +207,12 @@ func PushAsyncMessage(srv *Server, caller func() error, callback func(err error)
srv.pushMessage(msg) srv.pushMessage(msg)
} }
// GetSystemMessageAttrs 获取消息中的系统消息属性
func (slf *Message) GetSystemMessageAttrs() (handle func()) {
handle = slf.attrs[0].(func())
return
}
// PushSystemMessage 向特定服务器中推送 MessageTypeSystem 消息 // PushSystemMessage 向特定服务器中推送 MessageTypeSystem 消息
func PushSystemMessage(srv *Server, handle func(), mark ...any) { func PushSystemMessage(srv *Server, handle func(), mark ...any) {
msg := srv.messagePool.Get() msg := srv.messagePool.Get()

View File

@ -102,18 +102,17 @@ type Server struct {
} }
// Run 使用特定地址运行服务器 // Run 使用特定地址运行服务器
// // - server.NetworkTcp (addr:":8888")
// server.NetworkTcp (addr:":8888") // - server.NetworkTcp4 (addr:":8888")
// server.NetworkTcp4 (addr:":8888") // - server.NetworkTcp6 (addr:":8888")
// server.NetworkTcp6 (addr:":8888") // - server.NetworkUdp (addr:":8888")
// server.NetworkUdp (addr:":8888") // - server.NetworkUdp4 (addr:":8888")
// server.NetworkUdp4 (addr:":8888") // - server.NetworkUdp6 (addr:":8888")
// server.NetworkUdp6 (addr:":8888") // - server.NetworkUnix (addr:"socketPath")
// server.NetworkUnix (addr:"socketPath") // - server.NetworkHttp (addr:":8888")
// server.NetworkHttp (addr:":8888") // - server.NetworkWebsocket (addr:":8888/ws")
// server.NetworkWebsocket (addr:":8888/ws") // - server.NetworkKcp (addr:":8888")
// server.NetworkKcp (addr:":8888") // - server.NetworkNone (addr:"")
// server.NetworkNone (addr:"")
func (slf *Server) Run(addr string) error { func (slf *Server) Run(addr string) error {
if slf.network == NetworkNone { if slf.network == NetworkNone {
addr = "-" addr = "-"
@ -365,6 +364,16 @@ func (slf *Server) RunNone() error {
return slf.Run(str.None) return slf.Run(str.None)
} }
// Context 获取服务器上下文
func (slf *Server) Context() context.Context {
return slf.ctx
}
// TimeoutContext 获取服务器超时上下文context.WithTimeout 的简写
func (slf *Server) TimeoutContext(timeout time.Duration) (context.Context, context.CancelFunc) {
return context.WithTimeout(slf.ctx, timeout)
}
// GetOnlineCount 获取在线人数 // GetOnlineCount 获取在线人数
func (slf *Server) GetOnlineCount() int { func (slf *Server) GetOnlineCount() int {
return slf.online.Size() return slf.online.Size()
@ -546,7 +555,7 @@ func (slf *Server) pushMessage(message *Message) {
slf.messagePool.Release(message) slf.messagePool.Release(message)
return return
} }
if slf.isShutdown.Load() { if slf.isShutdown.Load() || !slf.OnMessageExecBeforeEvent(message) {
return return
} }
if slf.shuntChannels != nil && message.t == MessageTypePacket { if slf.shuntChannels != nil && message.t == MessageTypePacket {
@ -632,13 +641,12 @@ func (slf *Server) dispatchMessage(msg *Message) {
var attrs = msg.attrs var attrs = msg.attrs
switch msg.t { switch msg.t {
case MessageTypePacket: case MessageTypePacket:
var conn = attrs[0].(*Conn) var conn, packet = msg.GetPacketMessageAttrs()
var packet = attrs[1].([]byte)
if !slf.OnConnectionPacketPreprocessEvent(conn, packet, func(newPacket []byte) { packet = newPacket }) { if !slf.OnConnectionPacketPreprocessEvent(conn, packet, func(newPacket []byte) { packet = newPacket }) {
slf.OnConnectionReceivePacketEvent(conn, packet) slf.OnConnectionReceivePacketEvent(conn, packet)
} }
case MessageTypeError: case MessageTypeError:
err, action := attrs[0].(error), attrs[1].(MessageErrorAction) var err, action = msg.GetErrorMessageAttrs()
switch action { switch action {
case MessageErrorActionNone: case MessageErrorActionNone:
log.Panic("Server", log.Err(err)) log.Panic("Server", log.Err(err))
@ -648,12 +656,11 @@ func (slf *Server) dispatchMessage(msg *Message) {
log.Warn("Server", log.String("not support message error action", action.String())) log.Warn("Server", log.String("not support message error action", action.String()))
} }
case MessageTypeCross: case MessageTypeCross:
slf.OnReceiveCrossPacketEvent(attrs[0].(int64), attrs[1].([]byte)) slf.OnReceiveCrossPacketEvent(msg.GetCrossMessageAttrs())
case MessageTypeTicker: case MessageTypeTicker:
attrs[0].(func())() msg.GetTickerMessageAttrs()()
case MessageTypeAsync: case MessageTypeAsync:
handle := attrs[0].(func() error) handle, callback, cb := msg.GetAsyncMessageAttrs()
callback, cb := attrs[1].(func(err error))
if err := slf.ants.Submit(func() { if err := slf.ants.Submit(func() {
defer func() { defer func() {
if err := recover(); err != nil { if err := recover(); err != nil {
@ -688,10 +695,10 @@ func (slf *Server) dispatchMessage(msg *Message) {
}); err != nil { }); err != nil {
panic(err) panic(err)
} }
case MessageTypeAsyncCallback: case MessageTypeAsyncCallback: // 特殊类型
attrs[0].(func())() attrs[0].(func())()
case MessageTypeSystem: case MessageTypeSystem:
attrs[0].(func())() msg.GetSystemMessageAttrs()()
default: default:
log.Warn("Server", log.String("not support message type", msg.t.String())) log.Warn("Server", log.String("not support message type", msg.t.String()))
} }

View File

@ -5,13 +5,23 @@ import (
"github.com/kercylan98/minotaur/server" "github.com/kercylan98/minotaur/server"
"github.com/kercylan98/minotaur/server/client" "github.com/kercylan98/minotaur/server/client"
"github.com/kercylan98/minotaur/utils/times" "github.com/kercylan98/minotaur/utils/times"
"golang.org/x/time/rate"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time" "time"
) )
func TestNew(t *testing.T) { func TestNew(t *testing.T) {
limiter := rate.NewLimiter(rate.Every(time.Second), 100)
srv := server.New(server.NetworkWebsocket, server.WithMessageBufferSize(1024*1024), server.WithPProf()) srv := server.New(server.NetworkWebsocket, server.WithMessageBufferSize(1024*1024), server.WithPProf())
srv.RegMessageExecBeforeEvent(func(srv *server.Server, message *server.Message) bool {
t, c := srv.TimeoutContext(time.Second * 5)
defer c()
if err := limiter.Wait(t); err != nil {
return false
}
return true
})
srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet []byte) { srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet []byte) {
conn.Write(packet) conn.Write(packet)
}) })