diff --git a/go.mod b/go.mod index f0213c7..f44d03a 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,7 @@ require ( github.com/xtaci/kcp-go/v5 v5.6.3 go.uber.org/atomic v1.10.0 go.uber.org/zap v1.25.0 + golang.org/x/time v0.3.0 google.golang.org/grpc v1.57.0 ) diff --git a/server/client/client.go b/server/client/client.go index f24bf21..bdadcb3 100644 --- a/server/client/client.go +++ b/server/client/client.go @@ -8,7 +8,6 @@ import ( // NewClient 创建客户端 func NewClient(core Core) *Client { client := &Client{ - cond: sync.NewCond(&sync.Mutex{}), events: new(events), core: core, } @@ -24,11 +23,11 @@ func CloneClient(client *Client) *Client { type Client struct { *events core Core - cond *sync.Cond + mutex sync.Mutex packetPool *concurrent.Pool[*Packet] - packets []*Packet + packets chan *Packet - accumulate []*Packet + accumulate chan *Packet accumulation int // 积压消息数 } @@ -44,14 +43,12 @@ func (slf *Client) Run() error { }() err := <-runState if err != nil { - slf.cond.L.Lock() + slf.mutex.Lock() if slf.packetPool != nil { slf.packetPool.Close() slf.packetPool = nil } - slf.accumulate = append(slf.accumulate, slf.packets...) - slf.packets = nil - slf.cond.L.Unlock() + slf.mutex.Unlock() return err } var wait = new(sync.WaitGroup) @@ -69,12 +66,29 @@ func (slf *Client) IsConnected() bool { // Close 关闭 func (slf *Client) Close(err ...error) { + slf.mutex.Lock() + var unlock bool + defer func() { + if !unlock { + slf.mutex.Unlock() + } + }() slf.core.Close() if slf.packetPool != nil { slf.packetPool.Close() slf.packetPool = nil } + if slf.packets != nil { + close(slf.packets) + slf.packets = nil + } + if slf.accumulate != nil { + close(slf.accumulate) + slf.accumulate = nil + } slf.packets = nil + unlock = true + slf.mutex.Unlock() if len(err) > 0 { slf.OnConnectionClosedEvent(slf, err[0]) } else { @@ -96,7 +110,8 @@ func (slf *Client) Write(packet []byte, callback ...func(err error)) { // write 向连接中写入数据 // - messageType: websocket模式中指定消息类型 func (slf *Client) write(wst int, packet []byte, callback ...func(err error)) { - if slf.packetPool == nil { + slf.mutex.Lock() + if slf.packetPool == nil || slf.packets == nil { var p = &Packet{ wst: wst, data: packet, @@ -104,27 +119,26 @@ func (slf *Client) write(wst int, packet []byte, callback ...func(err error)) { if len(callback) > 0 { p.callback = callback[0] } - slf.cond.L.Lock() - slf.accumulate = append(slf.accumulate, p) + if slf.accumulate == nil { + slf.accumulate = make(chan *Packet, 1024*10) + } + slf.accumulate <- p + } else { + cp := slf.packetPool.Get() + cp.wst = wst + cp.data = packet + if len(callback) > 0 { + cp.callback = callback[0] + } + slf.packets <- cp slf.accumulation = len(slf.accumulate) + len(slf.packets) - slf.cond.L.Unlock() - return } - cp := slf.packetPool.Get() - cp.wst = wst - cp.data = packet - if len(callback) > 0 { - cp.callback = callback[0] - } - slf.cond.L.Lock() - slf.packets = append(slf.packets, cp) - slf.accumulation = len(slf.accumulate) + len(slf.packets) - slf.cond.Signal() - slf.cond.L.Unlock() + slf.mutex.Unlock() } // writeLoop 写循环 func (slf *Client) writeLoop(wait *sync.WaitGroup) { + slf.packets = make(chan *Packet, 1024*10) slf.packetPool = concurrent.NewPool[*Packet](10*1024, func() *Packet { return &Packet{} @@ -134,10 +148,11 @@ func (slf *Client) writeLoop(wait *sync.WaitGroup) { data.callback = nil }, ) - slf.cond.L.Lock() - slf.packets = append(slf.packets, slf.accumulate...) - slf.accumulate = nil - slf.cond.L.Unlock() + go func() { + for packet := range slf.accumulate { + slf.packets <- packet + } + }() defer func() { if err := recover(); err != nil { slf.Close(err.(error)) @@ -145,31 +160,19 @@ func (slf *Client) writeLoop(wait *sync.WaitGroup) { }() wait.Done() - for { - slf.cond.L.Lock() - if slf.packetPool == nil { - slf.cond.L.Unlock() - return + for packet := range slf.packets { + data := packet + var err = slf.core.Write(data) + callback := data.callback + slf.packetPool.Release(data) + if callback != nil { + callback(err) } - if len(slf.packets) == 0 { - slf.cond.Wait() - } - packets := slf.packets[0:] - slf.packets = slf.packets[0:0] - slf.cond.L.Unlock() - for i := 0; i < len(packets); i++ { - data := packets[i] - var err = slf.core.Write(data) - callback := data.callback - slf.packetPool.Release(data) - if callback != nil { - callback(err) - } - if err != nil { - panic(err) - } + if err != nil { + panic(err) } } + } func (slf *Client) onReceive(wst int, packet []byte) { diff --git a/server/conn.go b/server/conn.go index de25fa8..3437d83 100644 --- a/server/conn.go +++ b/server/conn.go @@ -4,11 +4,9 @@ import ( "context" "github.com/gorilla/websocket" "github.com/kercylan98/minotaur/utils/concurrent" - "github.com/kercylan98/minotaur/utils/log" "github.com/panjf2000/gnet" "github.com/xtaci/kcp-go/v5" "net" - "runtime/debug" "strings" "sync" ) @@ -18,7 +16,8 @@ func newKcpConn(server *Server, session *kcp.UDPSession) *Conn { c := &Conn{ ctx: server.ctx, connection: &connection{ - cond: sync.NewCond(&sync.Mutex{}), + packets: make(chan *connPacket, 1024*10), + mutex: new(sync.Mutex), server: server, remoteAddr: session.RemoteAddr(), ip: session.RemoteAddr().String(), @@ -41,7 +40,8 @@ func newGNetConn(server *Server, conn gnet.Conn) *Conn { c := &Conn{ ctx: server.ctx, connection: &connection{ - cond: sync.NewCond(&sync.Mutex{}), + packets: make(chan *connPacket, 1024*10), + mutex: new(sync.Mutex), server: server, remoteAddr: conn.RemoteAddr(), ip: conn.RemoteAddr().String(), @@ -64,7 +64,8 @@ func newWebsocketConn(server *Server, ws *websocket.Conn, ip string) *Conn { c := &Conn{ ctx: server.ctx, connection: &connection{ - cond: sync.NewCond(&sync.Mutex{}), + packets: make(chan *connPacket, 1024*10), + mutex: new(sync.Mutex), server: server, remoteAddr: ws.RemoteAddr(), ip: ip, @@ -84,9 +85,10 @@ func newGatewayConn(conn *Conn, connId string) *Conn { c := &Conn{ //ctx: server.ctx, connection: &connection{ - cond: sync.NewCond(&sync.Mutex{}), - server: conn.server, - data: map[any]any{}, + packets: make(chan *connPacket, 1024*10), + mutex: new(sync.Mutex), + server: conn.server, + data: map[any]any{}, }, } c.gw = func(packet []byte) { @@ -100,7 +102,8 @@ func NewEmptyConn(server *Server) *Conn { c := &Conn{ ctx: server.ctx, connection: &connection{ - cond: sync.NewCond(&sync.Mutex{}), + packets: make(chan *connPacket, 1024*10), + mutex: new(sync.Mutex), server: server, remoteAddr: &net.TCPAddr{}, ip: "0.0.0.0:0", @@ -123,6 +126,7 @@ type Conn struct { // connection 长久保持的连接 type connection struct { server *Server + mutex *sync.Mutex remoteAddr net.Addr ip string ws *websocket.Conn @@ -130,9 +134,8 @@ type connection struct { kcp *kcp.UDPSession gw func(packet []byte) data map[any]any - cond *sync.Cond packetPool *concurrent.Pool[*connPacket] - packets []*connPacket + packets chan *connPacket } // IsEmpty 是否是空连接 @@ -144,12 +147,6 @@ func (slf *Conn) IsEmpty() bool { // - 重用连接时,会将当前连接的数据复制到新连接中 // - 通常在于连接断开后,重新连接时使用 func (slf *Conn) Reuse(conn *Conn) { - slf.cond.L.Lock() - conn.cond.L.Lock() - defer func() { - slf.cond.L.Unlock() - conn.cond.L.Unlock() - }() slf.Close() slf.remoteAddr = conn.remoteAddr slf.ip = conn.ip @@ -190,7 +187,10 @@ func (slf *Conn) Close() { slf.packetPool.Close() } slf.packetPool = nil - slf.packets = nil + if slf.packets != nil { + close(slf.packets) + slf.packets = nil + } } // SetData 设置连接数据,该数据将在连接关闭前始终存在 @@ -243,12 +243,14 @@ func (slf *Conn) SetWST(wst int) *Conn { // Write 向连接中写入数据 // - messageType: websocket模式中指定消息类型 func (slf *Conn) Write(packet []byte, callback ...func(err error)) { + slf.mutex.Lock() + defer slf.mutex.Unlock() if slf.gw != nil { slf.gw(packet) return } packet = slf.server.OnConnectionWritePacketBeforeEvent(slf, packet) - if slf.packetPool == nil { + if slf.packetPool == nil || slf.packets == nil { return } cp := slf.packetPool.Get() @@ -257,10 +259,7 @@ func (slf *Conn) Write(packet []byte, callback ...func(err error)) { if len(callback) > 0 { cp.callback = callback[0] } - slf.cond.L.Lock() - slf.packets = append(slf.packets, cp) - slf.cond.Signal() - slf.cond.L.Unlock() + slf.packets <- cp } // writeLoop 写循环 @@ -277,49 +276,38 @@ func (slf *Conn) writeLoop(wait *sync.WaitGroup) { defer func() { if err := recover(); err != nil { slf.Close() - log.Error("WriteLoop", log.Any("Error", err)) - debug.PrintStack() + // TODO: 以下代码是否需要? + // log.Error("WriteLoop", log.Any("Error", err)) + // debug.PrintStack() } }() wait.Done() - for { - slf.cond.L.Lock() - if slf.packetPool == nil { - slf.cond.L.Unlock() - return - } - if len(slf.packets) == 0 { - slf.cond.Wait() - } - packets := slf.packets[0:] - slf.packets = slf.packets[0:0] - slf.cond.L.Unlock() - for i := 0; i < len(packets); i++ { - data := packets[i] - var err error - if slf.IsWebsocket() { - err = slf.ws.WriteMessage(data.wst, data.packet) - } else { - if slf.gn != nil { - switch slf.server.network { - case NetworkUdp, NetworkUdp4, NetworkUdp6: - err = slf.gn.SendTo(data.packet) - default: - err = slf.gn.AsyncWrite(data.packet) - } + for packet := range slf.packets { - } else if slf.kcp != nil { - _, err = slf.kcp.Write(data.packet) + data := packet + var err error + if slf.IsWebsocket() { + err = slf.ws.WriteMessage(data.wst, data.packet) + } else { + if slf.gn != nil { + switch slf.server.network { + case NetworkUdp, NetworkUdp4, NetworkUdp6: + err = slf.gn.SendTo(data.packet) + default: + err = slf.gn.AsyncWrite(data.packet) } + } else if slf.kcp != nil { + _, err = slf.kcp.Write(data.packet) } - callback := data.callback - slf.packetPool.Release(data) - if callback != nil { - callback(err) - } - if err != nil { - panic(err) - } + } + callback := data.callback + slf.packetPool.Release(data) + + if callback != nil { + callback(err) + } + if err != nil { + panic(err) } } } diff --git a/server/constants.go b/server/constants.go index b5033b0..f2119c8 100644 --- a/server/constants.go +++ b/server/constants.go @@ -22,7 +22,7 @@ const ( const ( DefaultMessageBufferSize = 1024 - DefaultMessageChannelSize = 1024 * 64 + DefaultMessageChannelSize = 1024 * 1024 DefaultAsyncPoolSize = 256 DefaultWebsocketReadDeadline = 30 * time.Second ) diff --git a/server/event.go b/server/event.go index d5f0dea..cf38d80 100644 --- a/server/event.go +++ b/server/event.go @@ -26,6 +26,7 @@ type ConnectionWritePacketBeforeEventHandle func(srv *Server, conn *Conn, packet type ShuntChannelCreatedEventHandle func(srv *Server, guid int64) type ShuntChannelClosedEventHandle func(srv *Server, guid int64) type ConnectionPacketPreprocessEventHandle func(srv *Server, conn *Conn, packet []byte, abort func(), usePacket func(newPacket []byte)) +type MessageExecBeforeEventHandle func(srv *Server, message *Message) bool func newEvent(srv *Server) *event { return &event{ @@ -44,6 +45,7 @@ func newEvent(srv *Server) *event { shuntChannelCreatedEventHandles: slice.NewPriority[ShuntChannelCreatedEventHandle](), shuntChannelClosedEventHandles: slice.NewPriority[ShuntChannelClosedEventHandle](), connectionPacketPreprocessEventHandles: slice.NewPriority[ConnectionPacketPreprocessEventHandle](), + messageExecBeforeEventHandles: slice.NewPriority[MessageExecBeforeEventHandle](), } } @@ -63,6 +65,7 @@ type event struct { shuntChannelCreatedEventHandles *slice.Priority[ShuntChannelCreatedEventHandle] shuntChannelClosedEventHandles *slice.Priority[ShuntChannelClosedEventHandle] connectionPacketPreprocessEventHandles *slice.Priority[ConnectionPacketPreprocessEventHandle] + messageExecBeforeEventHandles *slice.Priority[MessageExecBeforeEventHandle] consoleCommandEventHandles map[string]*slice.Priority[ConsoleCommandEventHandle] consoleCommandEventHandleInitOnce sync.Once @@ -353,6 +356,34 @@ func (slf *event) OnConnectionPacketPreprocessEvent(conn *Conn, packet []byte, u return abort } +// RegMessageExecBeforeEvent 在处理消息前将立刻执行被注册的事件处理函数 +// - 当返回 true 时,将继续执行后续的消息处理函数,否则将不会执行后续的消息处理函数,并且该消息将被丢弃 +// +// 适用于限流等场景 +func (slf *event) RegMessageExecBeforeEvent(handle MessageExecBeforeEventHandle, priority ...int) { + slf.messageExecBeforeEventHandles.Append(handle, slice.GetValue(priority, 0)) + log.Info("Server", log.String("RegEvent", runtimes.CurrentRunningFuncName()), log.String("handle", reflect.TypeOf(handle).String())) +} + +// OnMessageExecBeforeEvent 执行消息处理前的事件处理函数 +func (slf *event) OnMessageExecBeforeEvent(message *Message) bool { + if slf.messageExecBeforeEventHandles.Len() == 0 { + return true + } + var result = true + defer func() { + if err := recover(); err != nil { + log.Error("Server", log.String("OnMessageExecBeforeEvent", fmt.Sprintf("%v", err))) + debug.PrintStack() + } + }() + slf.messageExecBeforeEventHandles.RangeValue(func(index int, value MessageExecBeforeEventHandle) bool { + result = value(slf.Server, message) + return result + }) + return result +} + func (slf *event) check() { switch slf.network { case NetworkHttp, NetworkGRPC, NetworkNone: diff --git a/server/message.go b/server/message.go index 4b32257..7f027e9 100644 --- a/server/message.go +++ b/server/message.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "github.com/kercylan98/minotaur/utils/hash" "reflect" ) @@ -64,6 +65,11 @@ type ( MessageErrorAction byte ) +// HasMessageType 检查是否存在指定的消息类型 +func HasMessageType(mt MessageType) bool { + return hash.Exist(messageNames, mt) +} + func (slf MessageErrorAction) String() string { return messageErrorActionNames[slf] } @@ -74,6 +80,11 @@ type Message struct { attrs []any // 消息属性 } +// MessageType 返回消息类型 +func (slf *Message) MessageType() MessageType { + return slf.t +} + // String 返回消息的字符串表示 func (slf *Message) String() string { var attrs = make([]any, 0, len(slf.attrs)) @@ -86,7 +97,9 @@ func (slf *Message) String() string { var s string switch slf.t { case MessageTypePacket: - s = messagePacketVisualization(attrs[1].([]byte)) + if len(attrs) > 1 { + s = messagePacketVisualization(attrs[1].([]byte)) + } default: if len(slf.attrs) == 0 { s = "NoneAttr" @@ -104,6 +117,13 @@ func (slf MessageType) String() string { return messageNames[slf] } +// GetPacketMessageAttrs 获取消息中的数据包属性 +func (slf *Message) GetPacketMessageAttrs() (conn *Conn, packet []byte) { + conn = slf.attrs[0].(*Conn) + packet = slf.attrs[1].([]byte) + return +} + // PushPacketMessage 向特定服务器中推送 MessageTypePacket 消息 func PushPacketMessage(srv *Server, conn *Conn, wst int, packet []byte, mark ...any) { msg := srv.messagePool.Get() @@ -112,6 +132,13 @@ func PushPacketMessage(srv *Server, conn *Conn, wst int, packet []byte, mark ... srv.pushMessage(msg) } +// GetErrorMessageAttrs 获取消息中的错误属性 +func (slf *Message) GetErrorMessageAttrs() (err error, action MessageErrorAction) { + err = slf.attrs[0].(error) + action = slf.attrs[1].(MessageErrorAction) + return +} + // PushErrorMessage 向特定服务器中推送 MessageTypeError 消息 func PushErrorMessage(srv *Server, err error, action MessageErrorAction, mark ...any) { msg := srv.messagePool.Get() @@ -120,6 +147,13 @@ func PushErrorMessage(srv *Server, err error, action MessageErrorAction, mark .. srv.pushMessage(msg) } +// GetCrossMessageAttrs 获取消息中的跨服属性 +func (slf *Message) GetCrossMessageAttrs() (serverId int64, packet []byte) { + serverId = slf.attrs[0].(int64) + packet = slf.attrs[1].([]byte) + return +} + // PushCrossMessage 向特定服务器中推送 MessageTypeCross 消息 func PushCrossMessage(srv *Server, crossName string, serverId int64, packet []byte, mark ...any) { if serverId == srv.id { @@ -139,6 +173,12 @@ func PushCrossMessage(srv *Server, crossName string, serverId int64, packet []by } } +// GetTickerMessageAttrs 获取消息中的定时器属性 +func (slf *Message) GetTickerMessageAttrs() (caller func()) { + caller = slf.attrs[0].(func()) + return +} + // PushTickerMessage 向特定服务器中推送 MessageTypeTicker 消息 func PushTickerMessage(srv *Server, caller func(), mark ...any) { msg := srv.messagePool.Get() @@ -147,6 +187,13 @@ func PushTickerMessage(srv *Server, caller func(), mark ...any) { srv.pushMessage(msg) } +// GetAsyncMessageAttrs 获取消息中的异步消息属性 +func (slf *Message) GetAsyncMessageAttrs() (caller func() error, callback func(err error), hasCallback bool) { + caller = slf.attrs[0].(func() error) + callback, hasCallback = slf.attrs[1].(func(err error)) + return +} + // PushAsyncMessage 向特定服务器中推送 MessageTypeAsync 消息 // - 异步消息将在服务器的异步消息队列中进行处理,处理完成 caller 的阻塞操作后,将会通过系统消息执行 callback 函数 // - callback 函数将在异步消息处理完成后进行调用,无论过程是否产生 err,都将被执行,允许为 nil @@ -160,6 +207,12 @@ func PushAsyncMessage(srv *Server, caller func() error, callback func(err error) srv.pushMessage(msg) } +// GetSystemMessageAttrs 获取消息中的系统消息属性 +func (slf *Message) GetSystemMessageAttrs() (handle func()) { + handle = slf.attrs[0].(func()) + return +} + // PushSystemMessage 向特定服务器中推送 MessageTypeSystem 消息 func PushSystemMessage(srv *Server, handle func(), mark ...any) { msg := srv.messagePool.Get() diff --git a/server/server.go b/server/server.go index e0072ff..754f085 100644 --- a/server/server.go +++ b/server/server.go @@ -102,18 +102,17 @@ type Server struct { } // Run 使用特定地址运行服务器 -// -// server.NetworkTcp (addr:":8888") -// server.NetworkTcp4 (addr:":8888") -// server.NetworkTcp6 (addr:":8888") -// server.NetworkUdp (addr:":8888") -// server.NetworkUdp4 (addr:":8888") -// server.NetworkUdp6 (addr:":8888") -// server.NetworkUnix (addr:"socketPath") -// server.NetworkHttp (addr:":8888") -// server.NetworkWebsocket (addr:":8888/ws") -// server.NetworkKcp (addr:":8888") -// server.NetworkNone (addr:"") +// - server.NetworkTcp (addr:":8888") +// - server.NetworkTcp4 (addr:":8888") +// - server.NetworkTcp6 (addr:":8888") +// - server.NetworkUdp (addr:":8888") +// - server.NetworkUdp4 (addr:":8888") +// - server.NetworkUdp6 (addr:":8888") +// - server.NetworkUnix (addr:"socketPath") +// - server.NetworkHttp (addr:":8888") +// - server.NetworkWebsocket (addr:":8888/ws") +// - server.NetworkKcp (addr:":8888") +// - server.NetworkNone (addr:"") func (slf *Server) Run(addr string) error { if slf.network == NetworkNone { addr = "-" @@ -142,12 +141,13 @@ func (slf *Server) Run(addr string) error { if callback != nil { go callback() } - go func() { + go func(messageChannel <-chan *Message) { messageInitFinish <- struct{}{} - for message := range slf.messageChannel { - slf.dispatchMessage(message) + for message := range messageChannel { + msg := message + slf.dispatchMessage(msg) } - }() + }(slf.messageChannel) } switch slf.network { @@ -364,6 +364,16 @@ func (slf *Server) RunNone() error { return slf.Run(str.None) } +// Context 获取服务器上下文 +func (slf *Server) Context() context.Context { + return slf.ctx +} + +// TimeoutContext 获取服务器超时上下文,context.WithTimeout 的简写 +func (slf *Server) TimeoutContext(timeout time.Duration) (context.Context, context.CancelFunc) { + return context.WithTimeout(slf.ctx, timeout) +} + // GetOnlineCount 获取在线人数 func (slf *Server) GetOnlineCount() int { return slf.online.Size() @@ -541,13 +551,10 @@ func (slf *Server) ShuntChannelFreed(channelGuid int64) { // pushMessage 向服务器中写入特定类型的消息,需严格遵守消息属性要求 func (slf *Server) pushMessage(message *Message) { - if slf.messagePool.IsClose() { + if slf.messagePool.IsClose() || slf.isShutdown.Load() || !slf.OnMessageExecBeforeEvent(message) { slf.messagePool.Release(message) return } - if slf.isShutdown.Load() { - return - } if slf.shuntChannels != nil && message.t == MessageTypePacket { conn := message.attrs[0].(*Conn) channelGuid, allowToCreate := slf.shuntMatcher(conn) @@ -568,6 +575,7 @@ func (slf *Server) pushMessage(message *Message) { } } slf.messageChannel <- message + } func (slf *Server) low(message *Message, present time.Time, expect time.Duration, messageReplace ...string) { @@ -593,18 +601,18 @@ func (slf *Server) dispatchMessage(msg *Message) { ) if slf.deadlockDetect > 0 { ctx, cancel = context.WithTimeout(context.Background(), slf.deadlockDetect) - go func() { + go func(ctx context.Context, msg *Message) { select { case <-ctx.Done(): if err := ctx.Err(); err == context.DeadlineExceeded { log.Warn("Server", log.String("MessageType", messageNames[msg.t]), log.Any("SuspectedDeadlock", msg.attrs)) } } - }() + }(ctx, msg) } present := time.Now() - defer func() { + defer func(msg *Message) { if err := recover(); err != nil { stack := string(debug.Stack()) log.Error("Server", log.String("MessageType", messageNames[msg.t]), log.Any("MessageAttrs", msg.attrs), log.Any("error", err), log.String("stack", stack)) @@ -626,17 +634,16 @@ func (slf *Server) dispatchMessage(msg *Message) { slf.messagePool.Release(msg) } - }() + }(msg) var attrs = msg.attrs switch msg.t { case MessageTypePacket: - var conn = attrs[0].(*Conn) - var packet = attrs[1].([]byte) + var conn, packet = msg.GetPacketMessageAttrs() if !slf.OnConnectionPacketPreprocessEvent(conn, packet, func(newPacket []byte) { packet = newPacket }) { slf.OnConnectionReceivePacketEvent(conn, packet) } case MessageTypeError: - err, action := attrs[0].(error), attrs[1].(MessageErrorAction) + var err, action = msg.GetErrorMessageAttrs() switch action { case MessageErrorActionNone: log.Panic("Server", log.Err(err)) @@ -646,12 +653,11 @@ func (slf *Server) dispatchMessage(msg *Message) { log.Warn("Server", log.String("not support message error action", action.String())) } case MessageTypeCross: - slf.OnReceiveCrossPacketEvent(attrs[0].(int64), attrs[1].([]byte)) + slf.OnReceiveCrossPacketEvent(msg.GetCrossMessageAttrs()) case MessageTypeTicker: - attrs[0].(func())() + msg.GetTickerMessageAttrs()() case MessageTypeAsync: - handle := attrs[0].(func() error) - callback, cb := attrs[1].(func(err error)) + handle, callback, cb := msg.GetAsyncMessageAttrs() if err := slf.ants.Submit(func() { defer func() { if err := recover(); err != nil { @@ -686,10 +692,10 @@ func (slf *Server) dispatchMessage(msg *Message) { }); err != nil { panic(err) } - case MessageTypeAsyncCallback: + case MessageTypeAsyncCallback: // 特殊类型 attrs[0].(func())() case MessageTypeSystem: - attrs[0].(func())() + msg.GetSystemMessageAttrs()() default: log.Warn("Server", log.String("not support message type", msg.t.String())) } diff --git a/server/server_test.go b/server/server_test.go new file mode 100644 index 0000000..3b056e0 --- /dev/null +++ b/server/server_test.go @@ -0,0 +1,54 @@ +package server_test + +import ( + "fmt" + "github.com/kercylan98/minotaur/server" + "github.com/kercylan98/minotaur/server/client" + "github.com/kercylan98/minotaur/utils/times" + "golang.org/x/time/rate" + "sync/atomic" + "testing" + "time" +) + +func TestNew(t *testing.T) { + limiter := rate.NewLimiter(rate.Every(time.Second), 100) + srv := server.New(server.NetworkWebsocket, server.WithMessageBufferSize(1024*1024), server.WithPProf()) + srv.RegMessageExecBeforeEvent(func(srv *server.Server, message *server.Message) bool { + t, c := srv.TimeoutContext(time.Second * 5) + defer c() + if err := limiter.Wait(t); err != nil { + return false + } + return true + }) + srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet []byte) { + conn.Write(packet) + }) + if err := srv.Run(":9999"); err != nil { + panic(err) + } +} + +func TestNewClient(t *testing.T) { + var total atomic.Int64 + for i := 0; i < 1000; i++ { + cli := client.NewWebsocket("ws://127.0.0.1:9999") + cli.RegConnectionReceivePacketEvent(func(conn *client.Client, wst int, packet []byte) { + fmt.Println(string(packet)) + }) + cli.RegConnectionOpenedEvent(func(conn *client.Client) { + go func() { + for { + cli.WriteWS(2, []byte("hello")) + total.Add(1) + } + }() + }) + if err := cli.Run(); err != nil { + panic(err) + } + } + + time.Sleep(times.Week) +}