test: 完善 dispatcher.Dispatcher 注释及测试用例

This commit is contained in:
kercylan98 2024-01-12 13:48:57 +08:00
parent c439ef6424
commit a2a9199d41
5 changed files with 208 additions and 49 deletions

View File

@ -125,7 +125,7 @@ type connection struct {
gw func(packet []byte)
data map[any]any
closed bool
pool *hub.Pool[*connPacket]
pool *hub.ObjectPool[*connPacket]
loop writeloop.WriteLoop[*connPacket]
mu sync.Mutex
openTime time.Time

View File

@ -6,7 +6,7 @@ import (
"sync"
)
type hub struct {
type connMgr struct {
connections map[string]*Conn // 所有连接
register chan *Conn // 注册连接
@ -26,12 +26,12 @@ type hubBroadcast struct {
filter func(conn *Conn) bool // 过滤掉返回 false 的连接
}
func (h *hub) run(ctx context.Context) {
func (h *connMgr) run(ctx context.Context) {
h.connections = make(map[string]*Conn)
h.register = make(chan *Conn, DefaultConnHubBufferSize)
h.unregister = make(chan string, DefaultConnHubBufferSize)
h.broadcast = make(chan hubBroadcast, DefaultConnHubBufferSize)
go func(ctx context.Context, h *hub) {
go func(ctx context.Context, h *connMgr) {
for {
select {
case conn := <-h.register:
@ -54,7 +54,7 @@ func (h *hub) run(ctx context.Context) {
}
// registerConn 注册连接
func (h *hub) registerConn(conn *Conn) {
func (h *connMgr) registerConn(conn *Conn) {
select {
case h.register <- conn:
default:
@ -63,7 +63,7 @@ func (h *hub) registerConn(conn *Conn) {
}
// unregisterConn 注销连接
func (h *hub) unregisterConn(id string) {
func (h *connMgr) unregisterConn(id string) {
select {
case h.unregister <- id:
default:
@ -72,21 +72,21 @@ func (h *hub) unregisterConn(id string) {
}
// GetOnlineCount 获取在线人数
func (h *hub) GetOnlineCount() int {
func (h *connMgr) GetOnlineCount() int {
h.chanMutex.RLock()
defer h.chanMutex.RUnlock()
return h.onlineCount
}
// GetOnlineBotCount 获取在线机器人数量
func (h *hub) GetOnlineBotCount() int {
func (h *connMgr) GetOnlineBotCount() int {
h.chanMutex.RLock()
defer h.chanMutex.RUnlock()
return h.botCount
}
// IsOnline 是否在线
func (h *hub) IsOnline(id string) bool {
func (h *connMgr) IsOnline(id string) bool {
h.chanMutex.RLock()
_, exist := h.connections[id]
h.chanMutex.RUnlock()
@ -94,7 +94,7 @@ func (h *hub) IsOnline(id string) bool {
}
// GetOnlineAll 获取所有在线连接
func (h *hub) GetOnlineAll() map[string]*Conn {
func (h *connMgr) GetOnlineAll() map[string]*Conn {
h.chanMutex.RLock()
cop := collection.CloneMap(h.connections)
h.chanMutex.RUnlock()
@ -102,7 +102,7 @@ func (h *hub) GetOnlineAll() map[string]*Conn {
}
// GetOnline 获取在线连接
func (h *hub) GetOnline(id string) *Conn {
func (h *connMgr) GetOnline(id string) *Conn {
h.chanMutex.RLock()
conn := h.connections[id]
h.chanMutex.RUnlock()
@ -110,7 +110,7 @@ func (h *hub) GetOnline(id string) *Conn {
}
// CloseConn 关闭连接
func (h *hub) CloseConn(id string) {
func (h *connMgr) CloseConn(id string) {
h.chanMutex.RLock()
conn := h.connections[id]
h.chanMutex.RUnlock()
@ -120,7 +120,7 @@ func (h *hub) CloseConn(id string) {
}
// Broadcast 广播消息
func (h *hub) Broadcast(packet []byte, filter ...func(conn *Conn) bool) {
func (h *connMgr) Broadcast(packet []byte, filter ...func(conn *Conn) bool) {
m := hubBroadcast{
packet: packet,
}
@ -134,7 +134,7 @@ func (h *hub) Broadcast(packet []byte, filter ...func(conn *Conn) bool) {
}
}
func (h *hub) onRegister(conn *Conn) {
func (h *connMgr) onRegister(conn *Conn) {
h.chanMutex.Lock()
if h.closed {
conn.Close()
@ -148,7 +148,7 @@ func (h *hub) onRegister(conn *Conn) {
h.chanMutex.Unlock()
}
func (h *hub) onUnregister(id string) {
func (h *connMgr) onUnregister(id string) {
h.chanMutex.Lock()
if conn, ok := h.connections[id]; ok {
h.onlineCount--
@ -160,7 +160,7 @@ func (h *hub) onUnregister(id string) {
h.chanMutex.Unlock()
}
func (h *hub) onBroadcast(packet hubBroadcast) {
func (h *connMgr) onBroadcast(packet hubBroadcast) {
h.chanMutex.RLock()
defer h.chanMutex.RUnlock()
for _, conn := range h.connections {

View File

@ -1,6 +1,7 @@
package dispatcher
import (
"fmt"
"github.com/alphadose/haxmap"
"github.com/kercylan98/minotaur/utils/buffer"
"sync"
@ -12,8 +13,11 @@ var unique = struct{}{}
// Handler 消息处理器
type Handler[P Producer, M Message[P]] func(dispatcher *Dispatcher[P, M], message M)
// NewDispatcher 生成消息分发器
// NewDispatcher 创建一个新的消息分发器 Dispatcher 实例
func NewDispatcher[P Producer, M Message[P]](bufferSize int, name string, handler Handler[P, M]) *Dispatcher[P, M] {
if bufferSize <= 0 || handler == nil {
panic(fmt.Errorf("bufferSize must be greater than 0 and handler must not be nil, but got bufferSize: %d, handler is nil: %v", bufferSize, handler == nil))
}
d := &Dispatcher[P, M]{
name: name,
buf: buffer.NewRingUnbounded[M](bufferSize),
@ -26,7 +30,19 @@ func NewDispatcher[P Producer, M Message[P]](bufferSize int, name string, handle
return d
}
// Dispatcher 消息分发器
// Dispatcher 用于服务器消息处理的消息分发器
//
// 这个消息分发器为并发安全的生产者和消费者模型,生产者可以是任意类型,消费者必须是 Message 接口的实现。
// 生产者可以通过 Put 方法并发安全地将消息放入消息分发器,消息执行过程不会阻塞到 Put 方法,同时允许在 Start 方法之前调用 Put 方法。
// 在执行 Start 方法后,消息分发器会阻塞地从消息缓冲区中读取消息,然后执行消息处理器,消息处理器的执行过程不会阻塞到消息的生产。
//
// 为了保证消息不丢失,内部采用了 buffer.RingUnbounded 作为缓冲区实现,并且消息分发器不提供 Close 方法。
// 如果需要关闭消息分发器,可以通过 Expel 方法设置驱逐计划,当消息分发器中没有任何消息时,将会被释放。
// 同时,也可以使用 UnExpel 方法取消驱逐计划。
//
// 为什么提供 Expel 和 UnExpel 方法:
// - 在连接断开时,当需要执行一系列消息处理时,如果直接关闭消息分发器,可能会导致消息丢失。所以提供了 Expel 方法,可以在消息处理完成后再关闭消息分发器。
// - 当消息还未处理完成时连接重连,如果没有取消驱逐计划,可能会导致消息分发器被关闭。所以提供了 UnExpel 方法,可以在连接重连后取消驱逐计划。
type Dispatcher[P Producer, M Message[P]] struct {
buf *buffer.RingUnbounded[M]
uniques *haxmap.Map[string, struct{}]
@ -42,6 +58,7 @@ type Dispatcher[P Producer, M Message[P]] struct {
}
// SetProducerDoneHandler 设置特定生产者所有消息处理完成时的回调函数
// - 如果 handler 为 nil则会删除该生产者的回调函数
func (d *Dispatcher[P, M]) SetProducerDoneHandler(p P, handler func(p P, dispatcher *Dispatcher[P, M])) *Dispatcher[P, M] {
d.lock.Lock()
if handler == nil {

View File

@ -3,6 +3,7 @@ package dispatcher_test
import (
"github.com/kercylan98/minotaur/server/internal/dispatcher"
"sync"
"sync/atomic"
"testing"
"time"
)
@ -16,33 +17,174 @@ func (m *TestMessage) GetProducer() string {
return m.producer
}
func TestDispatcher_PutStartClose(t *testing.T) {
// 写入完成后,关闭分发器再开始分发,确保消息不会丢失
w := new(sync.WaitGroup)
cw := new(sync.WaitGroup)
cw.Add(1)
d := dispatcher.NewDispatcher[string, *TestMessage](1024*16, "test", func(dispatcher *dispatcher.Dispatcher[string, *TestMessage], message *TestMessage) {
t.Log(message)
w.Done()
}).SetClosedHandler(func(dispatcher *dispatcher.Dispatcher[string, *TestMessage]) {
t.Log("closed")
cw.Done()
})
for i := 0; i < 100; i++ {
w.Add(1)
d.Put(&TestMessage{
producer: "test",
v: i,
})
func TestNewDispatcher(t *testing.T) {
var cases = []struct {
name string
bufferSize int
handler dispatcher.Handler[string, *TestMessage]
shouldPanic bool
}{
{name: "TestNewDispatcher_BufferSize0AndHandlerNil", bufferSize: 0, handler: nil, shouldPanic: true},
{name: "TestNewDispatcher_BufferSize0AndHandlerNotNil", bufferSize: 0, handler: func(dispatcher *dispatcher.Dispatcher[string, *TestMessage], message *TestMessage) {}, shouldPanic: true},
{name: "TestNewDispatcher_BufferSize1AndHandlerNil", bufferSize: 1, handler: nil, shouldPanic: true},
{name: "TestNewDispatcher_BufferSize1AndHandlerNotNil", bufferSize: 1, handler: func(dispatcher *dispatcher.Dispatcher[string, *TestMessage], message *TestMessage) {}, shouldPanic: false},
}
d.Start()
d.Expel()
d.UnExpel()
w.Wait()
time.Sleep(time.Second)
d.Expel()
cw.Wait()
t.Log("done")
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
defer func() {
if r := recover(); r != nil && !c.shouldPanic {
t.Errorf("NewDispatcher() should not panic, but panic: %v", r)
}
}()
dispatcher.NewDispatcher(c.bufferSize, c.name, c.handler)
})
}
}
func TestDispatcher_SetProducerDoneHandler(t *testing.T) {
var cases = []struct {
name string
producer string
messageFinish *atomic.Bool
cancel bool
}{
{name: "TestDispatcher_SetProducerDoneHandlerNotCancel", producer: "producer", cancel: false},
{name: "TestDispatcher_SetProducerDoneHandlerCancel", producer: "producer", cancel: true},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
c.messageFinish = &atomic.Bool{}
w := new(sync.WaitGroup)
d := dispatcher.NewDispatcher(1024, c.name, func(dispatcher *dispatcher.Dispatcher[string, *TestMessage], message *TestMessage) {
w.Done()
})
d.SetProducerDoneHandler("producer", func(p string, dispatcher *dispatcher.Dispatcher[string, *TestMessage]) {
c.messageFinish.Store(true)
})
if c.cancel {
d.SetProducerDoneHandler("producer", nil)
}
d.Put(&TestMessage{producer: "producer"})
w.Add(1)
d.Start()
w.Wait()
if c.messageFinish.Load() && c.cancel {
t.Errorf("%s should cancel, but not", c.name)
}
})
}
}
func TestDispatcher_SetClosedHandler(t *testing.T) {
var cases = []struct {
name string
handlerFinishMsgCount *atomic.Int64
msgTime time.Duration
msgCount int
}{
{name: "TestDispatcher_SetClosedHandler_Normal", msgTime: 0, msgCount: 1},
{name: "TestDispatcher_SetClosedHandler_MessageCount1024", msgTime: 0, msgCount: 1024},
{name: "TestDispatcher_SetClosedHandler_MessageTime1sMessageCount3", msgTime: 1 * time.Second, msgCount: 3},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
c.handlerFinishMsgCount = &atomic.Int64{}
w := new(sync.WaitGroup)
d := dispatcher.NewDispatcher(1024, c.name, func(dispatcher *dispatcher.Dispatcher[string, *TestMessage], message *TestMessage) {
time.Sleep(c.msgTime)
c.handlerFinishMsgCount.Add(1)
})
d.SetClosedHandler(func(dispatcher *dispatcher.Dispatcher[string, *TestMessage]) {
w.Done()
})
for i := 0; i < c.msgCount; i++ {
d.Put(&TestMessage{producer: "producer"})
}
w.Add(1)
d.Start()
d.Expel()
w.Wait()
if c.handlerFinishMsgCount.Load() != int64(c.msgCount) {
t.Errorf("%s should finish %d messages, but finish %d", c.name, c.msgCount, c.handlerFinishMsgCount.Load())
}
})
}
}
func TestDispatcher_Expel(t *testing.T) {
var cases = []struct {
name string
handlerFinishMsgCount *atomic.Int64
msgTime time.Duration
msgCount int
}{
{name: "TestDispatcher_Expel_Normal", msgTime: 0, msgCount: 1},
{name: "TestDispatcher_Expel_MessageCount1024", msgTime: 0, msgCount: 1024},
{name: "TestDispatcher_Expel_MessageTime1sMessageCount3", msgTime: 1 * time.Second, msgCount: 3},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
c.handlerFinishMsgCount = &atomic.Int64{}
w := new(sync.WaitGroup)
d := dispatcher.NewDispatcher(1024, c.name, func(dispatcher *dispatcher.Dispatcher[string, *TestMessage], message *TestMessage) {
time.Sleep(c.msgTime)
c.handlerFinishMsgCount.Add(1)
})
d.SetClosedHandler(func(dispatcher *dispatcher.Dispatcher[string, *TestMessage]) {
w.Done()
})
for i := 0; i < c.msgCount; i++ {
d.Put(&TestMessage{producer: "producer"})
}
w.Add(1)
d.Start()
d.Expel()
w.Wait()
if c.handlerFinishMsgCount.Load() != int64(c.msgCount) {
t.Errorf("%s should finish %d messages, but finish %d", c.name, c.msgCount, c.handlerFinishMsgCount.Load())
}
})
}
}
func TestDispatcher_UnExpel(t *testing.T) {
var cases = []struct {
name string
closed *atomic.Bool
isUnExpel bool
expect bool
}{
{name: "TestDispatcher_UnExpel_Normal", isUnExpel: true, expect: false},
{name: "TestDispatcher_UnExpel_NotExpel", isUnExpel: false, expect: true},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
c.closed = &atomic.Bool{}
w := new(sync.WaitGroup)
d := dispatcher.NewDispatcher(1024, c.name, func(dispatcher *dispatcher.Dispatcher[string, *TestMessage], message *TestMessage) {
w.Done()
})
d.SetClosedHandler(func(dispatcher *dispatcher.Dispatcher[string, *TestMessage]) {
c.closed.Store(true)
})
d.Put(&TestMessage{producer: "producer"})
w.Add(1)
if c.isUnExpel {
d.Expel()
d.UnExpel()
} else {
d.Expel()
}
d.Start()
w.Wait()
if c.closed.Load() != c.expect {
t.Errorf("%s should %v, but %v", c.name, c.expect, c.closed.Load())
}
})
}
}

View File

@ -36,7 +36,7 @@ func New(network Network, options ...Option) *Server {
connWriteBufferSize: DefaultConnWriteBufferSize,
dispatcherBufferSize: DefaultDispatcherBufferSize,
},
hub: &hub{},
connMgr: &connMgr{},
option: &option{},
network: network,
closeChannel: make(chan struct{}, 1),
@ -68,7 +68,7 @@ type Server struct {
*event // 事件
*runtime // 运行时
*option // 可选项
*hub // 连接集合
*connMgr // 连接集合
dispatcherMgr *dispatcher.Manager[string, *Message] // 消息分发器管理器
ginServer *gin.Engine // HTTP模式下的路由器
httpServer *http.Server // HTTP模式下的服务器
@ -76,7 +76,7 @@ type Server struct {
gServer *gNet // TCP或UDP模式下的服务器
multiple *MultipleServer // 多服务器模式下的服务器
ants *ants.Pool // 协程池
messagePool *hub.Pool[*Message] // 消息池
messagePool *hub.ObjectPool[*Message] // 消息池
ctx context.Context // 上下文
cancel context.CancelFunc // 停止上下文
systemSignal chan os.Signal // 系统信号
@ -100,7 +100,7 @@ func (srv *Server) preCheckAndAdaptation(addr string) (startState <-chan error,
kcp.SystemTimedSched.Close()
}
srv.hub.run(srv.ctx)
srv.connMgr.run(srv.ctx)
return srv.network.adaptation(srv), nil
}