diff --git a/server/internal/dispatcher/dispatcher.go b/server/internal/dispatcher/dispatcher.go index 985e28b..7d696cd 100644 --- a/server/internal/dispatcher/dispatcher.go +++ b/server/internal/dispatcher/dispatcher.go @@ -137,9 +137,19 @@ func (d *Dispatcher[P, M]) IncrCount(producer P, i int64) { d.lock.Lock() defer d.lock.Unlock() d.mc += i - d.pmc[producer] += i - if d.expel && d.mc <= 0 { - close(d.abort) + pmc := d.pmc[producer] + i + d.pmc[producer] = pmc + if d.mc <= 0 { + if f := d.pmcF[producer]; f != nil && pmc <= 0 { + func(producer P) { + defer func(producer P) { + if err := super.RecoverTransform(recover()); err != nil { + log.Error("Dispatcher.ProducerDoneHandler", log.Any("producer", producer), log.Err(err)) + } + }(producer) + f(producer, &Action[P, M]{d: d, unlock: true}) + }(producer) + } } } diff --git a/server/message.go b/server/message.go index e86ebe1..05efd04 100644 --- a/server/message.go +++ b/server/message.go @@ -1,6 +1,7 @@ package server import ( + "github.com/kercylan98/minotaur/server/internal/dispatcher" "github.com/kercylan98/minotaur/utils/collection" "github.com/kercylan98/minotaur/utils/log" "github.com/kercylan98/minotaur/utils/super" @@ -75,16 +76,23 @@ func HasMessageType(mt MessageType) bool { // Message 服务器消息 type Message struct { - producer string + dis *dispatcher.Dispatcher[string, *Message] // 指定消息发送到特定的分发器 conn *Conn + err error ordinaryHandler func() exceptionHandler func() error errHandler func(err error) + marks []log.Field packet []byte - err error + producer string name string t MessageType - marks []log.Field +} + +// bindDispatcher 绑定分发器 +func (slf *Message) bindDispatcher(dis *dispatcher.Dispatcher[string, *Message]) *Message { + slf.dis = dis + return slf } func (slf *Message) GetProducer() string { @@ -103,6 +111,7 @@ func (slf *Message) reset() { slf.t = 0 slf.marks = nil slf.producer = "" + slf.dis = nil } // MessageType 返回消息类型 diff --git a/server/server.go b/server/server.go index c5c8d73..bcf1082 100644 --- a/server/server.go +++ b/server/server.go @@ -294,11 +294,8 @@ func (srv *Server) GetMessageCount() int64 { } // UseShunt 切换连接所使用的消息分流渠道,当分流渠道 name 不存在时将会创建一个新的分流渠道,否则将会加入已存在的分流渠道 -// - 默认情况下,所有连接都使用系统通道进行消息分发,当指定消息分流渠道时,将会使用指定的消息分流渠道进行消息分发 +// - 默认情况下,所有连接都使用系统通道进行消息分发,当指定消息分流渠道且为分流消息类型时,将会使用指定的消息分流渠道进行消息分发 // - 分流渠道会在连接断开时标记为驱逐状态,当分流渠道中的所有消息处理完毕且没有新连接使用时,将会被清除 -// -// 一些有趣的情况: -// - 当连接发送异步消息时,消息会被分为两部分,分别是异步部分和回调部分。异步部分会在当前的分流渠道中处理,而回调部分则是根据回调时所在的分流渠道进行处理 func (srv *Server) UseShunt(conn *Conn, name string) { srv.dispatcherMgr.BindProducer(conn.GetID(), name) } @@ -324,15 +321,17 @@ func (srv *Server) pushMessage(message *Message) { srv.messagePool.Release(message) return } - var d *dispatcher.Dispatcher[string, *Message] - switch message.t { - case MessageTypePacket, - MessageTypeShuntTicker, MessageTypeShuntAsync, MessageTypeShuntAsyncCallback, - MessageTypeUniqueShuntAsync, MessageTypeUniqueShuntAsyncCallback, - MessageTypeShunt: - d = srv.dispatcherMgr.GetDispatcher(message.conn.GetID()) - case MessageTypeSystem, MessageTypeAsync, MessageTypeUniqueAsync, MessageTypeAsyncCallback, MessageTypeUniqueAsyncCallback, MessageTypeTicker: - d = srv.dispatcherMgr.GetSystemDispatcher() + var d = message.dis + if d == nil { + switch message.t { + case MessageTypePacket, + MessageTypeShuntTicker, MessageTypeShuntAsync, MessageTypeShuntAsyncCallback, + MessageTypeUniqueShuntAsync, MessageTypeUniqueShuntAsyncCallback, + MessageTypeShunt: + d = srv.dispatcherMgr.GetDispatcher(message.conn.GetID()) + case MessageTypeSystem, MessageTypeAsync, MessageTypeUniqueAsync, MessageTypeAsyncCallback, MessageTypeUniqueAsyncCallback, MessageTypeTicker: + d = srv.dispatcherMgr.GetSystemDispatcher() + } } if d == nil { return @@ -403,8 +402,12 @@ func (srv *Server) dispatchMessage(dispatcherIns *dispatcher.Dispatcher[string, fmt.Println(stack) srv.OnMessageErrorEvent(msg, err) } - if msg.t == MessageTypeUniqueAsyncCallback || msg.t == MessageTypeUniqueShuntAsyncCallback { + switch msg.t { + case MessageTypeAsyncCallback, MessageTypeShuntAsyncCallback: + dispatcherIns.IncrCount(msg.producer, -1) + case MessageTypeUniqueAsyncCallback, MessageTypeUniqueShuntAsyncCallback: dispatcherIns.AntiUnique(msg.name) + dispatcherIns.IncrCount(msg.producer, -1) } srv.low(msg, present, time.Millisecond*100) @@ -455,25 +458,27 @@ func (srv *Server) dispatchMessage(dispatcherIns *dispatcher.Dispatcher[string, }(cancel, srv, dispatcherIns, msg, present) var err error if msg.exceptionHandler != nil { + dispatcherIns.IncrCount(msg.producer, 1) err = msg.exceptionHandler() } if msg.errHandler != nil { if msg.conn == nil { if msg.t == MessageTypeUniqueAsync { - srv.PushUniqueAsyncCallbackMessage(msg.name, err, msg.errHandler) + srv.pushUniqueAsyncCallbackMessage(dispatcherIns, msg.name, err, msg.errHandler) return } - srv.PushAsyncCallbackMessage(err, msg.errHandler) + srv.pushAsyncCallbackMessage(dispatcherIns, err, msg.errHandler) return } if msg.t == MessageTypeUniqueShuntAsync { - srv.PushUniqueShuntAsyncCallbackMessage(msg.conn, msg.name, err, msg.errHandler) + srv.pushUniqueShuntAsyncCallbackMessage(dispatcherIns, msg.conn, msg.name, err, msg.errHandler) return } - srv.PushShuntAsyncCallbackMessage(msg.conn, err, msg.errHandler) + srv.pushShuntAsyncCallbackMessage(dispatcherIns, msg.conn, err, msg.errHandler) return } dispatcherIns.AntiUnique(msg.name) + dispatcherIns.IncrCount(msg.producer, -1) if err != nil { log.Error("Server", log.String("MessageType", messageNames[msg.t]), log.Any("error", err), log.String("stack", string(debug.Stack()))) } @@ -505,11 +510,11 @@ func (srv *Server) PushAsyncMessage(caller func() error, callback func(err error srv.pushMessage(srv.messagePool.Get().castToAsyncMessage(caller, callback, mark...)) } -// PushAsyncCallbackMessage 向服务器中推送 MessageTypeAsyncCallback 消息 +// pushAsyncCallbackMessage 向服务器中推送 MessageTypeAsyncCallback 消息 // - 异步消息回调将会通过一个接收 error 的函数进行处理,该函数将在系统分发器中执行 // - mark 为可选的日志标记,当发生异常时,将会在日志中进行体现 -func (srv *Server) PushAsyncCallbackMessage(err error, callback func(err error), mark ...log.Field) { - srv.pushMessage(srv.messagePool.Get().castToAsyncCallbackMessage(err, callback, mark...)) +func (srv *Server) pushAsyncCallbackMessage(dis *dispatcher.Dispatcher[string, *Message], err error, callback func(err error), mark ...log.Field) { + srv.pushMessage(srv.messagePool.Get().castToAsyncCallbackMessage(err, callback, mark...).bindDispatcher(dis)) } // PushShuntAsyncMessage 向特定分发器中推送 MessageTypeAsync 消息,消息执行与 MessageTypeAsync 一致 @@ -519,10 +524,10 @@ func (srv *Server) PushShuntAsyncMessage(conn *Conn, caller func() error, callba srv.pushMessage(srv.messagePool.Get().castToShuntAsyncMessage(conn, caller, callback, mark...)) } -// PushShuntAsyncCallbackMessage 向特定分发器中推送 MessageTypeAsyncCallback 消息,消息执行与 MessageTypeAsyncCallback 一致 -// - 需要注意的是,当未指定 UseShunt 时,将会通过 PushAsyncCallbackMessage 进行转发 -func (srv *Server) PushShuntAsyncCallbackMessage(conn *Conn, err error, callback func(err error), mark ...log.Field) { - srv.pushMessage(srv.messagePool.Get().castToShuntAsyncCallbackMessage(conn, err, callback, mark...)) +// pushShuntAsyncCallbackMessage 向特定分发器中推送 MessageTypeAsyncCallback 消息,消息执行与 MessageTypeAsyncCallback 一致 +// - 需要注意的是,当未指定 UseShunt 时,将会通过 pushAsyncCallbackMessage 进行转发 +func (srv *Server) pushShuntAsyncCallbackMessage(dis *dispatcher.Dispatcher[string, *Message], conn *Conn, err error, callback func(err error), mark ...log.Field) { + srv.pushMessage(srv.messagePool.Get().castToShuntAsyncCallbackMessage(conn, err, callback, mark...).bindDispatcher(dis)) } // PushPacketMessage 向服务器中推送 MessageTypePacket 消息 @@ -558,9 +563,9 @@ func (srv *Server) PushUniqueAsyncMessage(unique string, caller func() error, ca srv.pushMessage(srv.messagePool.Get().castToUniqueAsyncMessage(unique, caller, callback, mark...)) } -// PushUniqueAsyncCallbackMessage 向服务器中推送 MessageTypeAsyncCallback 消息,消息执行与 MessageTypeAsyncCallback 一致 -func (srv *Server) PushUniqueAsyncCallbackMessage(unique string, err error, callback func(err error), mark ...log.Field) { - srv.pushMessage(srv.messagePool.Get().castToUniqueAsyncCallbackMessage(unique, err, callback, mark...)) +// pushUniqueAsyncCallbackMessage 向服务器中推送 MessageTypeAsyncCallback 消息,消息执行与 MessageTypeAsyncCallback 一致 +func (srv *Server) pushUniqueAsyncCallbackMessage(dis *dispatcher.Dispatcher[string, *Message], unique string, err error, callback func(err error), mark ...log.Field) { + srv.pushMessage(srv.messagePool.Get().castToUniqueAsyncCallbackMessage(unique, err, callback, mark...).bindDispatcher(dis)) } // PushUniqueShuntAsyncMessage 向特定分发器中推送 MessageTypeAsync 消息,消息执行与 MessageTypeAsync 一致 @@ -570,10 +575,10 @@ func (srv *Server) PushUniqueShuntAsyncMessage(conn *Conn, unique string, caller srv.pushMessage(srv.messagePool.Get().castToUniqueShuntAsyncMessage(conn, unique, caller, callback, mark...)) } -// PushUniqueShuntAsyncCallbackMessage 向特定分发器中推送 MessageTypeAsyncCallback 消息,消息执行与 MessageTypeAsyncCallback 一致 +// pushUniqueShuntAsyncCallbackMessage 向特定分发器中推送 MessageTypeAsyncCallback 消息,消息执行与 MessageTypeAsyncCallback 一致 // - 需要注意的是,当未指定 UseShunt 时,将会通过系统分流渠道进行转发 -func (srv *Server) PushUniqueShuntAsyncCallbackMessage(conn *Conn, unique string, err error, callback func(err error), mark ...log.Field) { - srv.pushMessage(srv.messagePool.Get().castToUniqueShuntAsyncCallbackMessage(conn, unique, err, callback, mark...)) +func (srv *Server) pushUniqueShuntAsyncCallbackMessage(dis *dispatcher.Dispatcher[string, *Message], conn *Conn, unique string, err error, callback func(err error), mark ...log.Field) { + srv.pushMessage(srv.messagePool.Get().castToUniqueShuntAsyncCallbackMessage(conn, unique, err, callback, mark...).bindDispatcher(dis)) } // PushShuntMessage 向特定分发器中推送 MessageTypeShunt 消息,消息执行与 MessageTypeSystem 一致,不同的是将会在特定分发器中执行