From a2a9199d415e9e54e33adc66a2e4dc0d20fff9b5 Mon Sep 17 00:00:00 2001 From: kercylan98 Date: Fri, 12 Jan 2024 13:48:57 +0800 Subject: [PATCH] =?UTF-8?q?test:=20=E5=AE=8C=E5=96=84=20dispatcher.Dispatc?= =?UTF-8?q?her=20=E6=B3=A8=E9=87=8A=E5=8F=8A=E6=B5=8B=E8=AF=95=E7=94=A8?= =?UTF-8?q?=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/conn.go | 2 +- server/{hub.go => conn_mgr.go} | 30 +-- server/internal/dispatcher/dispatcher.go | 21 +- server/internal/dispatcher/dispatcher_test.go | 196 +++++++++++++++--- server/server.go | 8 +- 5 files changed, 208 insertions(+), 49 deletions(-) rename server/{hub.go => conn_mgr.go} (80%) diff --git a/server/conn.go b/server/conn.go index c994a74..9d9784d 100644 --- a/server/conn.go +++ b/server/conn.go @@ -125,7 +125,7 @@ type connection struct { gw func(packet []byte) data map[any]any closed bool - pool *hub.Pool[*connPacket] + pool *hub.ObjectPool[*connPacket] loop writeloop.WriteLoop[*connPacket] mu sync.Mutex openTime time.Time diff --git a/server/hub.go b/server/conn_mgr.go similarity index 80% rename from server/hub.go rename to server/conn_mgr.go index cd56dfb..c60f6a5 100644 --- a/server/hub.go +++ b/server/conn_mgr.go @@ -6,7 +6,7 @@ import ( "sync" ) -type hub struct { +type connMgr struct { connections map[string]*Conn // 所有连接 register chan *Conn // 注册连接 @@ -26,12 +26,12 @@ type hubBroadcast struct { filter func(conn *Conn) bool // 过滤掉返回 false 的连接 } -func (h *hub) run(ctx context.Context) { +func (h *connMgr) run(ctx context.Context) { h.connections = make(map[string]*Conn) h.register = make(chan *Conn, DefaultConnHubBufferSize) h.unregister = make(chan string, DefaultConnHubBufferSize) h.broadcast = make(chan hubBroadcast, DefaultConnHubBufferSize) - go func(ctx context.Context, h *hub) { + go func(ctx context.Context, h *connMgr) { for { select { case conn := <-h.register: @@ -54,7 +54,7 @@ func (h *hub) run(ctx context.Context) { } // registerConn 注册连接 -func (h *hub) registerConn(conn *Conn) { +func (h *connMgr) registerConn(conn *Conn) { select { case h.register <- conn: default: @@ -63,7 +63,7 @@ func (h *hub) registerConn(conn *Conn) { } // unregisterConn 注销连接 -func (h *hub) unregisterConn(id string) { +func (h *connMgr) unregisterConn(id string) { select { case h.unregister <- id: default: @@ -72,21 +72,21 @@ func (h *hub) unregisterConn(id string) { } // GetOnlineCount 获取在线人数 -func (h *hub) GetOnlineCount() int { +func (h *connMgr) GetOnlineCount() int { h.chanMutex.RLock() defer h.chanMutex.RUnlock() return h.onlineCount } // GetOnlineBotCount 获取在线机器人数量 -func (h *hub) GetOnlineBotCount() int { +func (h *connMgr) GetOnlineBotCount() int { h.chanMutex.RLock() defer h.chanMutex.RUnlock() return h.botCount } // IsOnline 是否在线 -func (h *hub) IsOnline(id string) bool { +func (h *connMgr) IsOnline(id string) bool { h.chanMutex.RLock() _, exist := h.connections[id] h.chanMutex.RUnlock() @@ -94,7 +94,7 @@ func (h *hub) IsOnline(id string) bool { } // GetOnlineAll 获取所有在线连接 -func (h *hub) GetOnlineAll() map[string]*Conn { +func (h *connMgr) GetOnlineAll() map[string]*Conn { h.chanMutex.RLock() cop := collection.CloneMap(h.connections) h.chanMutex.RUnlock() @@ -102,7 +102,7 @@ func (h *hub) GetOnlineAll() map[string]*Conn { } // GetOnline 获取在线连接 -func (h *hub) GetOnline(id string) *Conn { +func (h *connMgr) GetOnline(id string) *Conn { h.chanMutex.RLock() conn := h.connections[id] h.chanMutex.RUnlock() @@ -110,7 +110,7 @@ func (h *hub) GetOnline(id string) *Conn { } // CloseConn 关闭连接 -func (h *hub) CloseConn(id string) { +func (h *connMgr) CloseConn(id string) { h.chanMutex.RLock() conn := h.connections[id] h.chanMutex.RUnlock() @@ -120,7 +120,7 @@ func (h *hub) CloseConn(id string) { } // Broadcast 广播消息 -func (h *hub) Broadcast(packet []byte, filter ...func(conn *Conn) bool) { +func (h *connMgr) Broadcast(packet []byte, filter ...func(conn *Conn) bool) { m := hubBroadcast{ packet: packet, } @@ -134,7 +134,7 @@ func (h *hub) Broadcast(packet []byte, filter ...func(conn *Conn) bool) { } } -func (h *hub) onRegister(conn *Conn) { +func (h *connMgr) onRegister(conn *Conn) { h.chanMutex.Lock() if h.closed { conn.Close() @@ -148,7 +148,7 @@ func (h *hub) onRegister(conn *Conn) { h.chanMutex.Unlock() } -func (h *hub) onUnregister(id string) { +func (h *connMgr) onUnregister(id string) { h.chanMutex.Lock() if conn, ok := h.connections[id]; ok { h.onlineCount-- @@ -160,7 +160,7 @@ func (h *hub) onUnregister(id string) { h.chanMutex.Unlock() } -func (h *hub) onBroadcast(packet hubBroadcast) { +func (h *connMgr) onBroadcast(packet hubBroadcast) { h.chanMutex.RLock() defer h.chanMutex.RUnlock() for _, conn := range h.connections { diff --git a/server/internal/dispatcher/dispatcher.go b/server/internal/dispatcher/dispatcher.go index 9e72f03..4c6a4e0 100644 --- a/server/internal/dispatcher/dispatcher.go +++ b/server/internal/dispatcher/dispatcher.go @@ -1,6 +1,7 @@ package dispatcher import ( + "fmt" "github.com/alphadose/haxmap" "github.com/kercylan98/minotaur/utils/buffer" "sync" @@ -12,8 +13,11 @@ var unique = struct{}{} // Handler 消息处理器 type Handler[P Producer, M Message[P]] func(dispatcher *Dispatcher[P, M], message M) -// NewDispatcher 生成消息分发器 +// NewDispatcher 创建一个新的消息分发器 Dispatcher 实例 func NewDispatcher[P Producer, M Message[P]](bufferSize int, name string, handler Handler[P, M]) *Dispatcher[P, M] { + if bufferSize <= 0 || handler == nil { + panic(fmt.Errorf("bufferSize must be greater than 0 and handler must not be nil, but got bufferSize: %d, handler is nil: %v", bufferSize, handler == nil)) + } d := &Dispatcher[P, M]{ name: name, buf: buffer.NewRingUnbounded[M](bufferSize), @@ -26,7 +30,19 @@ func NewDispatcher[P Producer, M Message[P]](bufferSize int, name string, handle return d } -// Dispatcher 消息分发器 +// Dispatcher 用于服务器消息处理的消息分发器 +// +// 这个消息分发器为并发安全的生产者和消费者模型,生产者可以是任意类型,消费者必须是 Message 接口的实现。 +// 生产者可以通过 Put 方法并发安全地将消息放入消息分发器,消息执行过程不会阻塞到 Put 方法,同时允许在 Start 方法之前调用 Put 方法。 +// 在执行 Start 方法后,消息分发器会阻塞地从消息缓冲区中读取消息,然后执行消息处理器,消息处理器的执行过程不会阻塞到消息的生产。 +// +// 为了保证消息不丢失,内部采用了 buffer.RingUnbounded 作为缓冲区实现,并且消息分发器不提供 Close 方法。 +// 如果需要关闭消息分发器,可以通过 Expel 方法设置驱逐计划,当消息分发器中没有任何消息时,将会被释放。 +// 同时,也可以使用 UnExpel 方法取消驱逐计划。 +// +// 为什么提供 Expel 和 UnExpel 方法: +// - 在连接断开时,当需要执行一系列消息处理时,如果直接关闭消息分发器,可能会导致消息丢失。所以提供了 Expel 方法,可以在消息处理完成后再关闭消息分发器。 +// - 当消息还未处理完成时连接重连,如果没有取消驱逐计划,可能会导致消息分发器被关闭。所以提供了 UnExpel 方法,可以在连接重连后取消驱逐计划。 type Dispatcher[P Producer, M Message[P]] struct { buf *buffer.RingUnbounded[M] uniques *haxmap.Map[string, struct{}] @@ -42,6 +58,7 @@ type Dispatcher[P Producer, M Message[P]] struct { } // SetProducerDoneHandler 设置特定生产者所有消息处理完成时的回调函数 +// - 如果 handler 为 nil,则会删除该生产者的回调函数 func (d *Dispatcher[P, M]) SetProducerDoneHandler(p P, handler func(p P, dispatcher *Dispatcher[P, M])) *Dispatcher[P, M] { d.lock.Lock() if handler == nil { diff --git a/server/internal/dispatcher/dispatcher_test.go b/server/internal/dispatcher/dispatcher_test.go index 456cc2c..74ca683 100644 --- a/server/internal/dispatcher/dispatcher_test.go +++ b/server/internal/dispatcher/dispatcher_test.go @@ -3,6 +3,7 @@ package dispatcher_test import ( "github.com/kercylan98/minotaur/server/internal/dispatcher" "sync" + "sync/atomic" "testing" "time" ) @@ -16,33 +17,174 @@ func (m *TestMessage) GetProducer() string { return m.producer } -func TestDispatcher_PutStartClose(t *testing.T) { - // 写入完成后,关闭分发器再开始分发,确保消息不会丢失 - w := new(sync.WaitGroup) - cw := new(sync.WaitGroup) - cw.Add(1) - d := dispatcher.NewDispatcher[string, *TestMessage](1024*16, "test", func(dispatcher *dispatcher.Dispatcher[string, *TestMessage], message *TestMessage) { - t.Log(message) - w.Done() - }).SetClosedHandler(func(dispatcher *dispatcher.Dispatcher[string, *TestMessage]) { - t.Log("closed") - cw.Done() - }) - - for i := 0; i < 100; i++ { - w.Add(1) - d.Put(&TestMessage{ - producer: "test", - v: i, - }) +func TestNewDispatcher(t *testing.T) { + var cases = []struct { + name string + bufferSize int + handler dispatcher.Handler[string, *TestMessage] + shouldPanic bool + }{ + {name: "TestNewDispatcher_BufferSize0AndHandlerNil", bufferSize: 0, handler: nil, shouldPanic: true}, + {name: "TestNewDispatcher_BufferSize0AndHandlerNotNil", bufferSize: 0, handler: func(dispatcher *dispatcher.Dispatcher[string, *TestMessage], message *TestMessage) {}, shouldPanic: true}, + {name: "TestNewDispatcher_BufferSize1AndHandlerNil", bufferSize: 1, handler: nil, shouldPanic: true}, + {name: "TestNewDispatcher_BufferSize1AndHandlerNotNil", bufferSize: 1, handler: func(dispatcher *dispatcher.Dispatcher[string, *TestMessage], message *TestMessage) {}, shouldPanic: false}, } - d.Start() - d.Expel() - d.UnExpel() - w.Wait() - time.Sleep(time.Second) - d.Expel() - cw.Wait() - t.Log("done") + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil && !c.shouldPanic { + t.Errorf("NewDispatcher() should not panic, but panic: %v", r) + } + }() + dispatcher.NewDispatcher(c.bufferSize, c.name, c.handler) + }) + } +} + +func TestDispatcher_SetProducerDoneHandler(t *testing.T) { + var cases = []struct { + name string + producer string + messageFinish *atomic.Bool + cancel bool + }{ + {name: "TestDispatcher_SetProducerDoneHandlerNotCancel", producer: "producer", cancel: false}, + {name: "TestDispatcher_SetProducerDoneHandlerCancel", producer: "producer", cancel: true}, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + c.messageFinish = &atomic.Bool{} + w := new(sync.WaitGroup) + d := dispatcher.NewDispatcher(1024, c.name, func(dispatcher *dispatcher.Dispatcher[string, *TestMessage], message *TestMessage) { + w.Done() + }) + d.SetProducerDoneHandler("producer", func(p string, dispatcher *dispatcher.Dispatcher[string, *TestMessage]) { + c.messageFinish.Store(true) + }) + if c.cancel { + d.SetProducerDoneHandler("producer", nil) + } + d.Put(&TestMessage{producer: "producer"}) + w.Add(1) + d.Start() + w.Wait() + if c.messageFinish.Load() && c.cancel { + t.Errorf("%s should cancel, but not", c.name) + } + }) + } +} + +func TestDispatcher_SetClosedHandler(t *testing.T) { + var cases = []struct { + name string + handlerFinishMsgCount *atomic.Int64 + msgTime time.Duration + msgCount int + }{ + {name: "TestDispatcher_SetClosedHandler_Normal", msgTime: 0, msgCount: 1}, + {name: "TestDispatcher_SetClosedHandler_MessageCount1024", msgTime: 0, msgCount: 1024}, + {name: "TestDispatcher_SetClosedHandler_MessageTime1sMessageCount3", msgTime: 1 * time.Second, msgCount: 3}, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + c.handlerFinishMsgCount = &atomic.Int64{} + w := new(sync.WaitGroup) + d := dispatcher.NewDispatcher(1024, c.name, func(dispatcher *dispatcher.Dispatcher[string, *TestMessage], message *TestMessage) { + time.Sleep(c.msgTime) + c.handlerFinishMsgCount.Add(1) + }) + d.SetClosedHandler(func(dispatcher *dispatcher.Dispatcher[string, *TestMessage]) { + w.Done() + }) + for i := 0; i < c.msgCount; i++ { + d.Put(&TestMessage{producer: "producer"}) + } + w.Add(1) + d.Start() + d.Expel() + w.Wait() + if c.handlerFinishMsgCount.Load() != int64(c.msgCount) { + t.Errorf("%s should finish %d messages, but finish %d", c.name, c.msgCount, c.handlerFinishMsgCount.Load()) + } + }) + } +} + +func TestDispatcher_Expel(t *testing.T) { + var cases = []struct { + name string + handlerFinishMsgCount *atomic.Int64 + msgTime time.Duration + msgCount int + }{ + {name: "TestDispatcher_Expel_Normal", msgTime: 0, msgCount: 1}, + {name: "TestDispatcher_Expel_MessageCount1024", msgTime: 0, msgCount: 1024}, + {name: "TestDispatcher_Expel_MessageTime1sMessageCount3", msgTime: 1 * time.Second, msgCount: 3}, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + c.handlerFinishMsgCount = &atomic.Int64{} + w := new(sync.WaitGroup) + d := dispatcher.NewDispatcher(1024, c.name, func(dispatcher *dispatcher.Dispatcher[string, *TestMessage], message *TestMessage) { + time.Sleep(c.msgTime) + c.handlerFinishMsgCount.Add(1) + }) + d.SetClosedHandler(func(dispatcher *dispatcher.Dispatcher[string, *TestMessage]) { + w.Done() + }) + for i := 0; i < c.msgCount; i++ { + d.Put(&TestMessage{producer: "producer"}) + } + w.Add(1) + d.Start() + d.Expel() + w.Wait() + if c.handlerFinishMsgCount.Load() != int64(c.msgCount) { + t.Errorf("%s should finish %d messages, but finish %d", c.name, c.msgCount, c.handlerFinishMsgCount.Load()) + } + }) + } +} + +func TestDispatcher_UnExpel(t *testing.T) { + var cases = []struct { + name string + closed *atomic.Bool + isUnExpel bool + expect bool + }{ + {name: "TestDispatcher_UnExpel_Normal", isUnExpel: true, expect: false}, + {name: "TestDispatcher_UnExpel_NotExpel", isUnExpel: false, expect: true}, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + c.closed = &atomic.Bool{} + w := new(sync.WaitGroup) + d := dispatcher.NewDispatcher(1024, c.name, func(dispatcher *dispatcher.Dispatcher[string, *TestMessage], message *TestMessage) { + w.Done() + }) + d.SetClosedHandler(func(dispatcher *dispatcher.Dispatcher[string, *TestMessage]) { + c.closed.Store(true) + }) + d.Put(&TestMessage{producer: "producer"}) + w.Add(1) + if c.isUnExpel { + d.Expel() + d.UnExpel() + } else { + d.Expel() + } + d.Start() + w.Wait() + if c.closed.Load() != c.expect { + t.Errorf("%s should %v, but %v", c.name, c.expect, c.closed.Load()) + } + }) + } } diff --git a/server/server.go b/server/server.go index 53d4a29..e03d445 100644 --- a/server/server.go +++ b/server/server.go @@ -36,7 +36,7 @@ func New(network Network, options ...Option) *Server { connWriteBufferSize: DefaultConnWriteBufferSize, dispatcherBufferSize: DefaultDispatcherBufferSize, }, - hub: &hub{}, + connMgr: &connMgr{}, option: &option{}, network: network, closeChannel: make(chan struct{}, 1), @@ -68,7 +68,7 @@ type Server struct { *event // 事件 *runtime // 运行时 *option // 可选项 - *hub // 连接集合 + *connMgr // 连接集合 dispatcherMgr *dispatcher.Manager[string, *Message] // 消息分发器管理器 ginServer *gin.Engine // HTTP模式下的路由器 httpServer *http.Server // HTTP模式下的服务器 @@ -76,7 +76,7 @@ type Server struct { gServer *gNet // TCP或UDP模式下的服务器 multiple *MultipleServer // 多服务器模式下的服务器 ants *ants.Pool // 协程池 - messagePool *hub.Pool[*Message] // 消息池 + messagePool *hub.ObjectPool[*Message] // 消息池 ctx context.Context // 上下文 cancel context.CancelFunc // 停止上下文 systemSignal chan os.Signal // 系统信号 @@ -100,7 +100,7 @@ func (srv *Server) preCheckAndAdaptation(addr string) (startState <-chan error, kcp.SystemTimedSched.Close() } - srv.hub.run(srv.ctx) + srv.connMgr.run(srv.ctx) return srv.network.adaptation(srv), nil }