From 811e1bd29ec4c4859a439c7bdc9655cd8abea635 Mon Sep 17 00:00:00 2001 From: kercylan98 Date: Mon, 21 Aug 2023 18:48:52 +0800 Subject: [PATCH] =?UTF-8?q?other:=20server=20=E5=BC=82=E6=AD=A5=E6=B6=88?= =?UTF-8?q?=E6=81=AF=E5=9B=9E=E8=B0=83=E5=B0=86=E4=B8=8D=E5=86=8D=E4=BD=BF?= =?UTF-8?q?=E7=94=A8=20MessageTypeSystem=EF=BC=8C=E6=9B=B4=E6=94=B9?= =?UTF-8?q?=E4=B8=BA=20MessageTypeAsyncCallback?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/message.go | 18 ++++++++++++------ server/server.go | 15 +++++++++++---- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/server/message.go b/server/message.go index 4dac411..8f7c17f 100644 --- a/server/message.go +++ b/server/message.go @@ -22,17 +22,21 @@ const ( // MessageTypeAsync 异步消息类型 MessageTypeAsync + // MessageTypeAsyncCallback 异步回调消息类型 + MessageTypeAsyncCallback + // MessageTypeSystem 系统消息类型 MessageTypeSystem ) var messageNames = map[MessageType]string{ - MessageTypePacket: "MessageTypePacket", - MessageTypeError: "MessageTypeError", - MessageTypeCross: "MessageTypeCross", - MessageTypeTicker: "MessageTypeTicker", - MessageTypeAsync: "MessageTypeAsync", - MessageTypeSystem: "MessageTypeSystem", + MessageTypePacket: "MessageTypePacket", + MessageTypeError: "MessageTypeError", + MessageTypeCross: "MessageTypeCross", + MessageTypeTicker: "MessageTypeTicker", + MessageTypeAsync: "MessageTypeAsync", + MessageTypeAsyncCallback: "MessageTypeAsyncCallback", + MessageTypeSystem: "MessageTypeSystem", } const ( @@ -146,6 +150,8 @@ func PushTickerMessage(srv *Server, caller func(), mark ...any) { // - 异步消息将在服务器的异步消息队列中进行处理,处理完成 caller 的阻塞操作后,将会通过系统消息执行 callback 函数 // - callback 函数将在异步消息处理完成后进行调用,无论过程是否产生 err,都将被执行,允许为 nil // - 需要注意的是,为了避免并发问题,caller 函数请仅处理阻塞操作,其他操作应该在 callback 函数中进行 +// +// 在通过 WithShunt 使用分流服务器时,异步消息不会转换到分流通道中进行处理。依旧需要注意上方第三条 func PushAsyncMessage(srv *Server, caller func() error, callback func(err error), mark ...any) { msg := srv.messagePool.Get() msg.t = MessageTypeAsync diff --git a/server/server.go b/server/server.go index 349b5bc..78b17ee 100644 --- a/server/server.go +++ b/server/server.go @@ -545,7 +545,7 @@ func (slf *Server) pushMessage(message *Message) { if slf.isShutdown.Load() { return } - if slf.shuntChannels != nil && (message.t == MessageTypePacket) { + if slf.shuntChannels != nil && message.t == MessageTypePacket { conn := message.attrs[0].(*Conn) channelGuid, allowToCreate := slf.shuntMatcher(conn) channel, exist := slf.shuntChannels.GetExist(channelGuid) @@ -681,15 +681,22 @@ func (slf *Server) dispatchMessage(msg *Message) { }() err := handle() if cb && callback != nil { - PushSystemMessage(slf, func() { - callback(err) - }, "AsyncCallback") + acm := slf.messagePool.Get() + acm.t = MessageTypeAsyncCallback + if len(attrs) > 2 { + acm.attrs = append([]any{func() { callback(err) }}, attrs[2:]...) + } else { + acm.attrs = []any{func() { callback(err) }} + } + slf.pushMessage(acm) } else if err != nil { log.Error("Server", log.String("MessageType", messageNames[msg.t]), log.Any("error", err), log.String("stack", string(debug.Stack()))) } }); err != nil { panic(err) } + case MessageTypeAsyncCallback: + attrs[0].(func())() case MessageTypeSystem: attrs[0].(func())() default: