refactor: 重构 server 包分流渠道设计,修复部分问题
- 使用 RingBuffer 实现分流渠道的无界缓冲区,修复分流渠道被关闭后,未处理完成的消息将会被丢弃的问题; - 移除 server.WithDisableAutomaticReleaseShunt 可选项,分流渠道将在消息处理完毕且没有连接使用时自行释放;
This commit is contained in:
parent
3402c83fd4
commit
3408c212d0
|
@ -366,7 +366,4 @@ func (slf *Conn) Close(err ...error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
slf.server.OnConnectionClosedEvent(slf, nil)
|
slf.server.OnConnectionClosedEvent(slf, nil)
|
||||||
if !slf.server.runtime.disableAutomaticReleaseShunt {
|
|
||||||
slf.server.releaseDispatcher(slf)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,7 +9,6 @@ import (
|
||||||
const (
|
const (
|
||||||
serverMultipleMark = "Minotaur Multiple Server"
|
serverMultipleMark = "Minotaur Multiple Server"
|
||||||
serverMark = "Minotaur Server"
|
serverMark = "Minotaur Server"
|
||||||
serverSystemDispatcher = "__system" // 系统消息分发器
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|
|
@ -1,83 +0,0 @@
|
||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/alphadose/haxmap"
|
|
||||||
"github.com/kercylan98/minotaur/utils/buffer"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
var dispatcherUnique = struct{}{}
|
|
||||||
|
|
||||||
// generateDispatcher 生成消息分发器
|
|
||||||
func generateDispatcher(name string, handler func(dispatcher *dispatcher, message *Message)) *dispatcher {
|
|
||||||
d := &dispatcher{
|
|
||||||
name: name,
|
|
||||||
buf: buffer.NewUnbounded[*Message](),
|
|
||||||
handler: handler,
|
|
||||||
uniques: haxmap.New[string, struct{}](),
|
|
||||||
}
|
|
||||||
return d
|
|
||||||
}
|
|
||||||
|
|
||||||
// dispatcher 消息分发器
|
|
||||||
type dispatcher struct {
|
|
||||||
name string
|
|
||||||
buf *buffer.Unbounded[*Message]
|
|
||||||
uniques *haxmap.Map[string, struct{}]
|
|
||||||
handler func(dispatcher *dispatcher, message *Message)
|
|
||||||
closed uint32
|
|
||||||
msgCount int64
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *dispatcher) unique(name string) bool {
|
|
||||||
_, loaded := d.uniques.GetOrSet(name, dispatcherUnique)
|
|
||||||
return loaded
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *dispatcher) antiUnique(name string) {
|
|
||||||
d.uniques.Del(name)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *dispatcher) start() {
|
|
||||||
defer d.buf.Close()
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case message, ok := <-d.buf.Get():
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
d.buf.Load()
|
|
||||||
d.handler(d, message)
|
|
||||||
|
|
||||||
if atomic.AddInt64(&d.msgCount, -1) <= 0 && atomic.LoadUint32(&d.closed) == 1 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *dispatcher) put(message *Message) {
|
|
||||||
if atomic.CompareAndSwapUint32(&d.closed, 1, 1) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
atomic.AddInt64(&d.msgCount, 1)
|
|
||||||
d.buf.Put(message)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *dispatcher) close() {
|
|
||||||
atomic.CompareAndSwapUint32(&d.closed, 0, 1)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
if d.buf.IsClosed() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if atomic.LoadInt64(&d.msgCount) <= 0 {
|
|
||||||
d.buf.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
|
@ -216,6 +216,7 @@ func (slf *event) OnConnectionClosedEvent(conn *Conn, err any) {
|
||||||
value(slf.Server, conn, err)
|
value(slf.Server, conn, err)
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
slf.Server.dispatcherMgr.UnBindProducer(conn.GetID())
|
||||||
}, log.String("Event", "OnConnectionClosedEvent"))
|
}, log.String("Event", "OnConnectionClosedEvent"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,153 @@
|
||||||
|
package dispatcher
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/alphadose/haxmap"
|
||||||
|
"github.com/kercylan98/minotaur/utils/buffer"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
var unique = struct{}{}
|
||||||
|
|
||||||
|
// Handler 消息处理器
|
||||||
|
type Handler[P Producer, M Message[P]] func(dispatcher *Dispatcher[P, M], message M)
|
||||||
|
|
||||||
|
// NewDispatcher 生成消息分发器
|
||||||
|
func NewDispatcher[P Producer, M Message[P]](bufferSize int, name string, handler Handler[P, M]) *Dispatcher[P, M] {
|
||||||
|
d := &Dispatcher[P, M]{
|
||||||
|
name: name,
|
||||||
|
buf: buffer.NewRingUnbounded[M](bufferSize),
|
||||||
|
handler: handler,
|
||||||
|
uniques: haxmap.New[string, struct{}](),
|
||||||
|
pmc: make(map[P]int64),
|
||||||
|
pmcF: make(map[P]func(p P, dispatcher *Dispatcher[P, M])),
|
||||||
|
abort: make(chan struct{}),
|
||||||
|
}
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dispatcher 消息分发器
|
||||||
|
type Dispatcher[P Producer, M Message[P]] struct {
|
||||||
|
buf *buffer.RingUnbounded[M]
|
||||||
|
uniques *haxmap.Map[string, struct{}]
|
||||||
|
handler Handler[P, M]
|
||||||
|
expel bool
|
||||||
|
mc int64
|
||||||
|
pmc map[P]int64
|
||||||
|
pmcF map[P]func(p P, dispatcher *Dispatcher[P, M])
|
||||||
|
lock sync.RWMutex
|
||||||
|
name string
|
||||||
|
closedHandler atomic.Pointer[func(dispatcher *Dispatcher[P, M])]
|
||||||
|
abort chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetProducerDoneHandler 设置特定生产者所有消息处理完成时的回调函数
|
||||||
|
func (d *Dispatcher[P, M]) SetProducerDoneHandler(p P, handler func(p P, dispatcher *Dispatcher[P, M])) *Dispatcher[P, M] {
|
||||||
|
d.lock.Lock()
|
||||||
|
if handler == nil {
|
||||||
|
delete(d.pmcF, p)
|
||||||
|
} else {
|
||||||
|
d.pmcF[p] = handler
|
||||||
|
}
|
||||||
|
d.lock.Unlock()
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetClosedHandler 设置消息分发器关闭时的回调函数
|
||||||
|
func (d *Dispatcher[P, M]) SetClosedHandler(handler func(dispatcher *Dispatcher[P, M])) *Dispatcher[P, M] {
|
||||||
|
d.closedHandler.Store(&handler)
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name 获取消息分发器名称
|
||||||
|
func (d *Dispatcher[P, M]) Name() string {
|
||||||
|
return d.name
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unique 设置唯一消息键,返回是否已存在
|
||||||
|
func (d *Dispatcher[P, M]) Unique(name string) bool {
|
||||||
|
_, loaded := d.uniques.GetOrSet(name, unique)
|
||||||
|
return loaded
|
||||||
|
}
|
||||||
|
|
||||||
|
// AntiUnique 取消唯一消息键
|
||||||
|
func (d *Dispatcher[P, M]) AntiUnique(name string) {
|
||||||
|
d.uniques.Del(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expel 设置该消息分发器即将被驱逐,当消息分发器中没有任何消息时,会自动关闭
|
||||||
|
func (d *Dispatcher[P, M]) Expel() {
|
||||||
|
d.lock.Lock()
|
||||||
|
d.expel = true
|
||||||
|
if d.mc <= 0 {
|
||||||
|
d.abort <- struct{}{}
|
||||||
|
}
|
||||||
|
d.lock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnExpel 取消特定生产者的驱逐计划
|
||||||
|
func (d *Dispatcher[P, M]) UnExpel() {
|
||||||
|
d.lock.Lock()
|
||||||
|
d.expel = false
|
||||||
|
d.lock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IncrCount 主动增量设置特定生产者的消息计数,这在等待异步消息完成后再关闭消息分发器时非常有用
|
||||||
|
func (d *Dispatcher[P, M]) IncrCount(producer P, i int64) {
|
||||||
|
d.lock.Lock()
|
||||||
|
d.mc += i
|
||||||
|
d.pmc[producer] += i
|
||||||
|
if d.expel && d.mc <= 0 {
|
||||||
|
d.abort <- struct{}{}
|
||||||
|
}
|
||||||
|
d.lock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put 将消息放入分发器
|
||||||
|
func (d *Dispatcher[P, M]) Put(message M) {
|
||||||
|
d.lock.Lock()
|
||||||
|
d.mc++
|
||||||
|
d.pmc[message.GetProducer()]++
|
||||||
|
d.lock.Unlock()
|
||||||
|
d.buf.Write(message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start 以阻塞的方式开始进行消息分发,当消息分发器中没有任何消息时,会自动关闭
|
||||||
|
func (d *Dispatcher[P, M]) Start() *Dispatcher[P, M] {
|
||||||
|
go func(d *Dispatcher[P, M]) {
|
||||||
|
process:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-d.abort:
|
||||||
|
d.buf.Close()
|
||||||
|
break process
|
||||||
|
case message := <-d.buf.Read():
|
||||||
|
d.handler(d, message)
|
||||||
|
d.lock.Lock()
|
||||||
|
d.mc--
|
||||||
|
p := message.GetProducer()
|
||||||
|
pmc := d.pmc[p] - 1
|
||||||
|
d.pmc[p] = pmc
|
||||||
|
if f := d.pmcF[p]; f != nil && pmc <= 0 {
|
||||||
|
go f(p, d)
|
||||||
|
}
|
||||||
|
if d.mc <= 0 && d.expel {
|
||||||
|
d.buf.Close()
|
||||||
|
break process
|
||||||
|
}
|
||||||
|
d.lock.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
closedHandler := *(d.closedHandler.Load())
|
||||||
|
if closedHandler != nil {
|
||||||
|
closedHandler(d)
|
||||||
|
}
|
||||||
|
close(d.abort)
|
||||||
|
}(d)
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
// Closed 判断消息分发器是否已关闭
|
||||||
|
func (d *Dispatcher[P, M]) Closed() bool {
|
||||||
|
return d.buf.Closed()
|
||||||
|
}
|
|
@ -0,0 +1,48 @@
|
||||||
|
package dispatcher_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/kercylan98/minotaur/server/internal/dispatcher"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TestMessage struct {
|
||||||
|
producer string
|
||||||
|
v int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *TestMessage) GetProducer() string {
|
||||||
|
return m.producer
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDispatcher_PutStartClose(t *testing.T) {
|
||||||
|
// 写入完成后,关闭分发器再开始分发,确保消息不会丢失
|
||||||
|
w := new(sync.WaitGroup)
|
||||||
|
cw := new(sync.WaitGroup)
|
||||||
|
cw.Add(1)
|
||||||
|
d := dispatcher.NewDispatcher[string, *TestMessage](1024*16, "test", func(dispatcher *dispatcher.Dispatcher[string, *TestMessage], message *TestMessage) {
|
||||||
|
t.Log(message)
|
||||||
|
w.Done()
|
||||||
|
}).SetClosedHandler(func(dispatcher *dispatcher.Dispatcher[string, *TestMessage]) {
|
||||||
|
t.Log("closed")
|
||||||
|
cw.Done()
|
||||||
|
})
|
||||||
|
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
w.Add(1)
|
||||||
|
d.Put(&TestMessage{
|
||||||
|
producer: "test",
|
||||||
|
v: i,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
d.Start()
|
||||||
|
d.Expel()
|
||||||
|
d.UnExpel()
|
||||||
|
w.Wait()
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
d.Expel()
|
||||||
|
cw.Wait()
|
||||||
|
t.Log("done")
|
||||||
|
}
|
|
@ -0,0 +1,153 @@
|
||||||
|
package dispatcher
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
const SystemName = "*system"
|
||||||
|
|
||||||
|
// NewManager 生成消息分发器管理器
|
||||||
|
func NewManager[P Producer, M Message[P]](bufferSize int, handler Handler[P, M]) *Manager[P, M] {
|
||||||
|
mgr := &Manager[P, M]{
|
||||||
|
handler: handler,
|
||||||
|
dispatchers: make(map[string]*Dispatcher[P, M]),
|
||||||
|
member: make(map[string]map[P]struct{}),
|
||||||
|
sys: NewDispatcher(bufferSize, SystemName, handler).Start(),
|
||||||
|
curr: make(map[P]*Dispatcher[P, M]),
|
||||||
|
size: bufferSize,
|
||||||
|
}
|
||||||
|
|
||||||
|
return mgr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Manager 消息分发器管理器
|
||||||
|
type Manager[P Producer, M Message[P]] struct {
|
||||||
|
handler Handler[P, M] // 消息处理器
|
||||||
|
sys *Dispatcher[P, M] // 系统消息分发器
|
||||||
|
dispatchers map[string]*Dispatcher[P, M] // 当前所有正在工作的消息分发器
|
||||||
|
member map[string]map[P]struct{} // 当前正在工作的消息分发器对应的生产者
|
||||||
|
curr map[P]*Dispatcher[P, M] // 当前特定生产者正在使用的消息分发器
|
||||||
|
lock sync.RWMutex // 消息分发器锁
|
||||||
|
w sync.WaitGroup // 消息分发器等待组
|
||||||
|
size int // 消息分发器缓冲区大小
|
||||||
|
|
||||||
|
closedHandler func(name string)
|
||||||
|
createdHandler func(name string)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDispatcherClosedHandler 设置消息分发器关闭时的回调函数
|
||||||
|
func (m *Manager[P, M]) SetDispatcherClosedHandler(handler func(name string)) *Manager[P, M] {
|
||||||
|
m.closedHandler = handler
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDispatcherCreatedHandler 设置消息分发器创建时的回调函数
|
||||||
|
func (m *Manager[P, M]) SetDispatcherCreatedHandler(handler func(name string)) *Manager[P, M] {
|
||||||
|
m.createdHandler = handler
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasDispatcher 检查是否存在指定名称的消息分发器
|
||||||
|
func (m *Manager[P, M]) HasDispatcher(name string) bool {
|
||||||
|
m.lock.RLock()
|
||||||
|
defer m.lock.RUnlock()
|
||||||
|
_, exist := m.dispatchers[name]
|
||||||
|
return exist
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDispatcherNum 获取当前正在工作的消息分发器数量
|
||||||
|
func (m *Manager[P, M]) GetDispatcherNum() int {
|
||||||
|
m.lock.RLock()
|
||||||
|
defer m.lock.RUnlock()
|
||||||
|
return len(m.dispatchers) + 1 // +1 系统消息分发器
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSystemDispatcher 获取系统消息分发器
|
||||||
|
func (m *Manager[P, M]) GetSystemDispatcher() *Dispatcher[P, M] {
|
||||||
|
return m.sys
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDispatcher 获取生产者正在使用的消息分发器,如果生产者没有绑定消息分发器,则会返回系统消息分发器
|
||||||
|
func (m *Manager[P, M]) GetDispatcher(p P) *Dispatcher[P, M] {
|
||||||
|
m.lock.Lock()
|
||||||
|
defer m.lock.Unlock()
|
||||||
|
|
||||||
|
curr, exist := m.curr[p]
|
||||||
|
if exist {
|
||||||
|
return curr
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.sys
|
||||||
|
}
|
||||||
|
|
||||||
|
// BindProducer 绑定生产者使用特定的消息分发器,如果生产者已经绑定了消息分发器,则会先解绑
|
||||||
|
func (m *Manager[P, M]) BindProducer(p P, name string) {
|
||||||
|
if name == SystemName {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.lock.Lock()
|
||||||
|
defer m.lock.Unlock()
|
||||||
|
member, exist := m.member[name]
|
||||||
|
if !exist {
|
||||||
|
member = make(map[P]struct{})
|
||||||
|
m.member[name] = member
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, exist = member[p]; exist {
|
||||||
|
d := m.dispatchers[name]
|
||||||
|
d.SetProducerDoneHandler(p, nil)
|
||||||
|
d.UnExpel()
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
curr, exist := m.curr[p]
|
||||||
|
if exist {
|
||||||
|
delete(m.member[curr.name], p)
|
||||||
|
if len(m.member[curr.name]) == 0 {
|
||||||
|
curr.Expel()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatcher, exist := m.dispatchers[name]
|
||||||
|
if !exist {
|
||||||
|
dispatcher = NewDispatcher(m.size, name, m.handler).SetClosedHandler(func(dispatcher *Dispatcher[P, M]) {
|
||||||
|
// 消息分发器关闭时,将会将其从管理器中移除
|
||||||
|
m.lock.Lock()
|
||||||
|
delete(m.dispatchers, dispatcher.name)
|
||||||
|
delete(m.member, dispatcher.name)
|
||||||
|
m.lock.Unlock()
|
||||||
|
if m.closedHandler != nil {
|
||||||
|
m.closedHandler(dispatcher.name)
|
||||||
|
}
|
||||||
|
}).Start()
|
||||||
|
m.dispatchers[name] = dispatcher
|
||||||
|
defer func(m *Manager[P, M], name string) {
|
||||||
|
if m.createdHandler != nil {
|
||||||
|
m.createdHandler(name)
|
||||||
|
}
|
||||||
|
}(m, dispatcher.Name())
|
||||||
|
}
|
||||||
|
m.curr[p] = dispatcher
|
||||||
|
member[p] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnBindProducer 解绑生产者使用特定的消息分发器
|
||||||
|
func (m *Manager[P, M]) UnBindProducer(p P) {
|
||||||
|
m.lock.Lock()
|
||||||
|
defer m.lock.Unlock()
|
||||||
|
curr, exist := m.curr[p]
|
||||||
|
if !exist {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
curr.SetProducerDoneHandler(p, func(p P, dispatcher *Dispatcher[P, M]) {
|
||||||
|
m.lock.Lock()
|
||||||
|
defer m.lock.Unlock()
|
||||||
|
delete(m.member[dispatcher.name], p)
|
||||||
|
delete(m.curr, p)
|
||||||
|
if len(m.member[dispatcher.name]) == 0 {
|
||||||
|
dispatcher.Expel()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
|
@ -0,0 +1,46 @@
|
||||||
|
package dispatcher_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/kercylan98/minotaur/server/internal/dispatcher"
|
||||||
|
"github.com/kercylan98/minotaur/utils/times"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestManager(t *testing.T) {
|
||||||
|
var mgr *dispatcher.Manager[string, *TestMessage]
|
||||||
|
var onHandler = func(dispatcher *dispatcher.Dispatcher[string, *TestMessage], message *TestMessage) {
|
||||||
|
t.Log(dispatcher.Name(), message, mgr.GetDispatcherNum())
|
||||||
|
switch message.v {
|
||||||
|
case 4:
|
||||||
|
mgr.UnBindProducer("test")
|
||||||
|
t.Log("UnBindProducer")
|
||||||
|
case 6:
|
||||||
|
mgr.BindProducer(message.GetProducer(), "test-dispatcher")
|
||||||
|
t.Log("BindProducer")
|
||||||
|
case 9:
|
||||||
|
dispatcher.Put(&TestMessage{
|
||||||
|
producer: "test",
|
||||||
|
v: 10,
|
||||||
|
})
|
||||||
|
case 10:
|
||||||
|
mgr.UnBindProducer("test")
|
||||||
|
t.Log("UnBindProducer", mgr.GetDispatcherNum())
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
mgr = dispatcher.NewManager[string, *TestMessage](1024*16, onHandler)
|
||||||
|
|
||||||
|
mgr.BindProducer("test", "test-dispatcher")
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
d := mgr.GetDispatcher("test").SetClosedHandler(func(dispatcher *dispatcher.Dispatcher[string, *TestMessage]) {
|
||||||
|
t.Log("closed")
|
||||||
|
})
|
||||||
|
d.Put(&TestMessage{
|
||||||
|
producer: "test",
|
||||||
|
v: i,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(times.Day)
|
||||||
|
}
|
|
@ -0,0 +1,6 @@
|
||||||
|
package dispatcher
|
||||||
|
|
||||||
|
type Message[P comparable] interface {
|
||||||
|
// GetProducer 获取消息生产者
|
||||||
|
GetProducer() P
|
||||||
|
}
|
|
@ -0,0 +1,5 @@
|
||||||
|
package dispatcher
|
||||||
|
|
||||||
|
type Producer interface {
|
||||||
|
comparable
|
||||||
|
}
|
|
@ -75,6 +75,7 @@ func HasMessageType(mt MessageType) bool {
|
||||||
|
|
||||||
// Message 服务器消息
|
// Message 服务器消息
|
||||||
type Message struct {
|
type Message struct {
|
||||||
|
producer string
|
||||||
conn *Conn
|
conn *Conn
|
||||||
ordinaryHandler func()
|
ordinaryHandler func()
|
||||||
exceptionHandler func() error
|
exceptionHandler func() error
|
||||||
|
@ -86,6 +87,10 @@ type Message struct {
|
||||||
marks []log.Field
|
marks []log.Field
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (slf *Message) GetProducer() string {
|
||||||
|
return slf.producer
|
||||||
|
}
|
||||||
|
|
||||||
// reset 重置消息结构体
|
// reset 重置消息结构体
|
||||||
func (slf *Message) reset() {
|
func (slf *Message) reset() {
|
||||||
slf.conn = nil
|
slf.conn = nil
|
||||||
|
@ -126,78 +131,91 @@ func (slf MessageType) String() string {
|
||||||
|
|
||||||
// castToPacketMessage 将消息转换为数据包消息
|
// castToPacketMessage 将消息转换为数据包消息
|
||||||
func (slf *Message) castToPacketMessage(conn *Conn, packet []byte, mark ...log.Field) *Message {
|
func (slf *Message) castToPacketMessage(conn *Conn, packet []byte, mark ...log.Field) *Message {
|
||||||
|
slf.producer = conn.GetID()
|
||||||
slf.t, slf.conn, slf.packet, slf.marks = MessageTypePacket, conn, packet, mark
|
slf.t, slf.conn, slf.packet, slf.marks = MessageTypePacket, conn, packet, mark
|
||||||
return slf
|
return slf
|
||||||
}
|
}
|
||||||
|
|
||||||
// castToTickerMessage 将消息转换为定时器消息
|
// castToTickerMessage 将消息转换为定时器消息
|
||||||
func (slf *Message) castToTickerMessage(name string, caller func(), mark ...log.Field) *Message {
|
func (slf *Message) castToTickerMessage(name string, caller func(), mark ...log.Field) *Message {
|
||||||
|
slf.producer = "sys"
|
||||||
slf.t, slf.name, slf.ordinaryHandler, slf.marks = MessageTypeTicker, name, caller, mark
|
slf.t, slf.name, slf.ordinaryHandler, slf.marks = MessageTypeTicker, name, caller, mark
|
||||||
return slf
|
return slf
|
||||||
}
|
}
|
||||||
|
|
||||||
// castToShuntTickerMessage 将消息转换为分发器定时器消息
|
// castToShuntTickerMessage 将消息转换为分发器定时器消息
|
||||||
func (slf *Message) castToShuntTickerMessage(conn *Conn, name string, caller func(), mark ...log.Field) *Message {
|
func (slf *Message) castToShuntTickerMessage(conn *Conn, name string, caller func(), mark ...log.Field) *Message {
|
||||||
|
slf.producer = conn.GetID()
|
||||||
slf.t, slf.conn, slf.name, slf.ordinaryHandler, slf.marks = MessageTypeShuntTicker, conn, name, caller, mark
|
slf.t, slf.conn, slf.name, slf.ordinaryHandler, slf.marks = MessageTypeShuntTicker, conn, name, caller, mark
|
||||||
return slf
|
return slf
|
||||||
}
|
}
|
||||||
|
|
||||||
// castToAsyncMessage 将消息转换为异步消息
|
// castToAsyncMessage 将消息转换为异步消息
|
||||||
func (slf *Message) castToAsyncMessage(caller func() error, callback func(err error), mark ...log.Field) *Message {
|
func (slf *Message) castToAsyncMessage(caller func() error, callback func(err error), mark ...log.Field) *Message {
|
||||||
|
slf.producer = "sys"
|
||||||
slf.t, slf.exceptionHandler, slf.errHandler, slf.marks = MessageTypeAsync, caller, callback, mark
|
slf.t, slf.exceptionHandler, slf.errHandler, slf.marks = MessageTypeAsync, caller, callback, mark
|
||||||
return slf
|
return slf
|
||||||
}
|
}
|
||||||
|
|
||||||
// castToAsyncCallbackMessage 将消息转换为异步回调消息
|
// castToAsyncCallbackMessage 将消息转换为异步回调消息
|
||||||
func (slf *Message) castToAsyncCallbackMessage(err error, caller func(err error), mark ...log.Field) *Message {
|
func (slf *Message) castToAsyncCallbackMessage(err error, caller func(err error), mark ...log.Field) *Message {
|
||||||
|
slf.producer = "sys"
|
||||||
slf.t, slf.err, slf.errHandler, slf.marks = MessageTypeAsyncCallback, err, caller, mark
|
slf.t, slf.err, slf.errHandler, slf.marks = MessageTypeAsyncCallback, err, caller, mark
|
||||||
return slf
|
return slf
|
||||||
}
|
}
|
||||||
|
|
||||||
// castToShuntAsyncMessage 将消息转换为分流异步消息
|
// castToShuntAsyncMessage 将消息转换为分流异步消息
|
||||||
func (slf *Message) castToShuntAsyncMessage(conn *Conn, caller func() error, callback func(err error), mark ...log.Field) *Message {
|
func (slf *Message) castToShuntAsyncMessage(conn *Conn, caller func() error, callback func(err error), mark ...log.Field) *Message {
|
||||||
|
slf.producer = conn.GetID()
|
||||||
slf.t, slf.conn, slf.exceptionHandler, slf.errHandler, slf.marks = MessageTypeShuntAsync, conn, caller, callback, mark
|
slf.t, slf.conn, slf.exceptionHandler, slf.errHandler, slf.marks = MessageTypeShuntAsync, conn, caller, callback, mark
|
||||||
return slf
|
return slf
|
||||||
}
|
}
|
||||||
|
|
||||||
// castToShuntAsyncCallbackMessage 将消息转换为分流异步回调消息
|
// castToShuntAsyncCallbackMessage 将消息转换为分流异步回调消息
|
||||||
func (slf *Message) castToShuntAsyncCallbackMessage(conn *Conn, err error, caller func(err error), mark ...log.Field) *Message {
|
func (slf *Message) castToShuntAsyncCallbackMessage(conn *Conn, err error, caller func(err error), mark ...log.Field) *Message {
|
||||||
|
slf.producer = conn.GetID()
|
||||||
slf.t, slf.conn, slf.err, slf.errHandler, slf.marks = MessageTypeShuntAsyncCallback, conn, err, caller, mark
|
slf.t, slf.conn, slf.err, slf.errHandler, slf.marks = MessageTypeShuntAsyncCallback, conn, err, caller, mark
|
||||||
return slf
|
return slf
|
||||||
}
|
}
|
||||||
|
|
||||||
// castToUniqueAsyncMessage 将消息转换为唯一异步消息
|
// castToUniqueAsyncMessage 将消息转换为唯一异步消息
|
||||||
func (slf *Message) castToUniqueAsyncMessage(unique string, caller func() error, callback func(err error), mark ...log.Field) *Message {
|
func (slf *Message) castToUniqueAsyncMessage(unique string, caller func() error, callback func(err error), mark ...log.Field) *Message {
|
||||||
|
slf.producer = "sys"
|
||||||
slf.t, slf.name, slf.exceptionHandler, slf.errHandler, slf.marks = MessageTypeUniqueAsync, unique, caller, callback, mark
|
slf.t, slf.name, slf.exceptionHandler, slf.errHandler, slf.marks = MessageTypeUniqueAsync, unique, caller, callback, mark
|
||||||
return slf
|
return slf
|
||||||
}
|
}
|
||||||
|
|
||||||
// castToUniqueAsyncCallbackMessage 将消息转换为唯一异步回调消息
|
// castToUniqueAsyncCallbackMessage 将消息转换为唯一异步回调消息
|
||||||
func (slf *Message) castToUniqueAsyncCallbackMessage(unique string, err error, caller func(err error), mark ...log.Field) *Message {
|
func (slf *Message) castToUniqueAsyncCallbackMessage(unique string, err error, caller func(err error), mark ...log.Field) *Message {
|
||||||
|
slf.producer = "sys"
|
||||||
slf.t, slf.name, slf.err, slf.errHandler, slf.marks = MessageTypeUniqueAsyncCallback, unique, err, caller, mark
|
slf.t, slf.name, slf.err, slf.errHandler, slf.marks = MessageTypeUniqueAsyncCallback, unique, err, caller, mark
|
||||||
return slf
|
return slf
|
||||||
}
|
}
|
||||||
|
|
||||||
// castToUniqueShuntAsyncMessage 将消息转换为唯一分流异步消息
|
// castToUniqueShuntAsyncMessage 将消息转换为唯一分流异步消息
|
||||||
func (slf *Message) castToUniqueShuntAsyncMessage(conn *Conn, unique string, caller func() error, callback func(err error), mark ...log.Field) *Message {
|
func (slf *Message) castToUniqueShuntAsyncMessage(conn *Conn, unique string, caller func() error, callback func(err error), mark ...log.Field) *Message {
|
||||||
|
slf.producer = conn.GetID()
|
||||||
slf.t, slf.conn, slf.name, slf.exceptionHandler, slf.errHandler, slf.marks = MessageTypeUniqueShuntAsync, conn, unique, caller, callback, mark
|
slf.t, slf.conn, slf.name, slf.exceptionHandler, slf.errHandler, slf.marks = MessageTypeUniqueShuntAsync, conn, unique, caller, callback, mark
|
||||||
return slf
|
return slf
|
||||||
}
|
}
|
||||||
|
|
||||||
// castToUniqueShuntAsyncCallbackMessage 将消息转换为唯一分流异步回调消息
|
// castToUniqueShuntAsyncCallbackMessage 将消息转换为唯一分流异步回调消息
|
||||||
func (slf *Message) castToUniqueShuntAsyncCallbackMessage(conn *Conn, unique string, err error, caller func(err error), mark ...log.Field) *Message {
|
func (slf *Message) castToUniqueShuntAsyncCallbackMessage(conn *Conn, unique string, err error, caller func(err error), mark ...log.Field) *Message {
|
||||||
|
slf.producer = conn.GetID()
|
||||||
slf.t, slf.conn, slf.name, slf.err, slf.errHandler, slf.marks = MessageTypeUniqueShuntAsyncCallback, conn, unique, err, caller, mark
|
slf.t, slf.conn, slf.name, slf.err, slf.errHandler, slf.marks = MessageTypeUniqueShuntAsyncCallback, conn, unique, err, caller, mark
|
||||||
return slf
|
return slf
|
||||||
}
|
}
|
||||||
|
|
||||||
// castToSystemMessage 将消息转换为系统消息
|
// castToSystemMessage 将消息转换为系统消息
|
||||||
func (slf *Message) castToSystemMessage(caller func(), mark ...log.Field) *Message {
|
func (slf *Message) castToSystemMessage(caller func(), mark ...log.Field) *Message {
|
||||||
|
slf.producer = "sys"
|
||||||
slf.t, slf.ordinaryHandler, slf.marks = MessageTypeSystem, caller, mark
|
slf.t, slf.ordinaryHandler, slf.marks = MessageTypeSystem, caller, mark
|
||||||
return slf
|
return slf
|
||||||
}
|
}
|
||||||
|
|
||||||
// castToShuntMessage 将消息转换为分流消息
|
// castToShuntMessage 将消息转换为分流消息
|
||||||
func (slf *Message) castToShuntMessage(conn *Conn, caller func(), mark ...log.Field) *Message {
|
func (slf *Message) castToShuntMessage(conn *Conn, caller func(), mark ...log.Field) *Message {
|
||||||
|
slf.producer = conn.GetID()
|
||||||
slf.t, slf.conn, slf.ordinaryHandler, slf.marks = MessageTypeShunt, conn, caller, mark
|
slf.t, slf.conn, slf.ordinaryHandler, slf.marks = MessageTypeShunt, conn, caller, mark
|
||||||
return slf
|
return slf
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,9 +49,9 @@ type runtime struct {
|
||||||
messageStatistics []*atomic.Int64 // 消息统计数量
|
messageStatistics []*atomic.Int64 // 消息统计数量
|
||||||
messageStatisticsLock *sync.RWMutex // 消息统计锁
|
messageStatisticsLock *sync.RWMutex // 消息统计锁
|
||||||
connWriteBufferSize int // 连接写入缓冲区大小
|
connWriteBufferSize int // 连接写入缓冲区大小
|
||||||
disableAutomaticReleaseShunt bool // 是否禁用自动释放分流渠道
|
|
||||||
websocketUpgrader *websocket.Upgrader // websocket 升级器
|
websocketUpgrader *websocket.Upgrader // websocket 升级器
|
||||||
websocketConnInitializer func(writer http.ResponseWriter, request *http.Request, conn *websocket.Conn) error // websocket 连接初始化
|
websocketConnInitializer func(writer http.ResponseWriter, request *http.Request, conn *websocket.Conn) error // websocket 连接初始化
|
||||||
|
dispatcherBufferSize int // 消息分发器缓冲区大小
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithWebsocketConnInitializer 通过 websocket 连接初始化的方式创建服务器,当 initializer 返回错误时,服务器将不会处理该连接的后续逻辑
|
// WithWebsocketConnInitializer 通过 websocket 连接初始化的方式创建服务器,当 initializer 返回错误时,服务器将不会处理该连接的后续逻辑
|
||||||
|
@ -77,14 +77,6 @@ func WithWebsocketUpgrade(upgrader *websocket.Upgrader) Option {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithDisableAutomaticReleaseShunt 通过禁用自动释放分流渠道的方式创建服务器
|
|
||||||
// - 默认不开启,当禁用自动释放分流渠道时,服务器将不会在连接断开时自动释放分流渠道,需要手动调用 ReleaseShunt 方法释放
|
|
||||||
func WithDisableAutomaticReleaseShunt() Option {
|
|
||||||
return func(srv *Server) {
|
|
||||||
srv.runtime.disableAutomaticReleaseShunt = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithConnWriteBufferSize 通过连接写入缓冲区大小的方式创建服务器
|
// WithConnWriteBufferSize 通过连接写入缓冲区大小的方式创建服务器
|
||||||
// - 默认值为 DefaultConnWriteBufferSize
|
// - 默认值为 DefaultConnWriteBufferSize
|
||||||
// - 设置合适的缓冲区大小可以提高服务器性能,但是会占用更多的内存
|
// - 设置合适的缓冲区大小可以提高服务器性能,但是会占用更多的内存
|
||||||
|
@ -100,14 +92,14 @@ func WithConnWriteBufferSize(size int) Option {
|
||||||
// WithDispatcherBufferSize 通过消息分发器缓冲区大小的方式创建服务器
|
// WithDispatcherBufferSize 通过消息分发器缓冲区大小的方式创建服务器
|
||||||
// - 默认值为 DefaultDispatcherBufferSize
|
// - 默认值为 DefaultDispatcherBufferSize
|
||||||
// - 设置合适的缓冲区大小可以提高服务器性能,但是会占用更多的内存
|
// - 设置合适的缓冲区大小可以提高服务器性能,但是会占用更多的内存
|
||||||
//func WithDispatcherBufferSize(size int) Option {
|
func WithDispatcherBufferSize(size int) Option {
|
||||||
// return func(srv *Server) {
|
return func(srv *Server) {
|
||||||
// if size <= 0 {
|
if size <= 0 {
|
||||||
// return
|
return
|
||||||
// }
|
}
|
||||||
// srv.dispatcherBufferSize = size
|
srv.dispatcherBufferSize = size
|
||||||
// }
|
}
|
||||||
//}
|
}
|
||||||
|
|
||||||
// WithMessageStatistics 通过消息统计的方式创建服务器
|
// WithMessageStatistics 通过消息统计的方式创建服务器
|
||||||
// - 默认不开启,当 duration 和 limit 均大于 0 的时候,服务器将记录每 duration 期间的消息数量,并保留最多 limit 条
|
// - 默认不开启,当 duration 和 limit 均大于 0 的时候,服务器将记录每 duration 期间的消息数量,并保留最多 limit 条
|
||||||
|
|
153
server/server.go
153
server/server.go
|
@ -5,6 +5,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/kercylan98/minotaur/server/internal/dispatcher"
|
||||||
"github.com/kercylan98/minotaur/server/internal/logger"
|
"github.com/kercylan98/minotaur/server/internal/logger"
|
||||||
"github.com/kercylan98/minotaur/utils/concurrent"
|
"github.com/kercylan98/minotaur/utils/concurrent"
|
||||||
"github.com/kercylan98/minotaur/utils/log"
|
"github.com/kercylan98/minotaur/utils/log"
|
||||||
|
@ -21,7 +22,6 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
@ -34,15 +34,13 @@ func New(network Network, options ...Option) *Server {
|
||||||
runtime: &runtime{
|
runtime: &runtime{
|
||||||
packetWarnSize: DefaultPacketWarnSize,
|
packetWarnSize: DefaultPacketWarnSize,
|
||||||
connWriteBufferSize: DefaultConnWriteBufferSize,
|
connWriteBufferSize: DefaultConnWriteBufferSize,
|
||||||
|
dispatcherBufferSize: DefaultDispatcherBufferSize,
|
||||||
},
|
},
|
||||||
hub: &hub{},
|
hub: &hub{},
|
||||||
option: &option{},
|
option: &option{},
|
||||||
network: network,
|
network: network,
|
||||||
closeChannel: make(chan struct{}, 1),
|
closeChannel: make(chan struct{}, 1),
|
||||||
systemSignal: make(chan os.Signal, 1),
|
systemSignal: make(chan os.Signal, 1),
|
||||||
dispatchers: make(map[string]*dispatcher),
|
|
||||||
dispatcherMember: map[string]map[string]*Conn{},
|
|
||||||
currDispatcher: map[string]*dispatcher{},
|
|
||||||
}
|
}
|
||||||
server.ctx, server.cancel = context.WithCancel(context.Background())
|
server.ctx, server.cancel = context.WithCancel(context.Background())
|
||||||
server.event = newEvent(server)
|
server.event = newEvent(server)
|
||||||
|
@ -71,6 +69,7 @@ type Server struct {
|
||||||
*runtime // 运行时
|
*runtime // 运行时
|
||||||
*option // 可选项
|
*option // 可选项
|
||||||
*hub // 连接集合
|
*hub // 连接集合
|
||||||
|
dispatcherMgr *dispatcher.Manager[string, *Message] // 消息分发器管理器
|
||||||
ginServer *gin.Engine // HTTP模式下的路由器
|
ginServer *gin.Engine // HTTP模式下的路由器
|
||||||
httpServer *http.Server // HTTP模式下的服务器
|
httpServer *http.Server // HTTP模式下的服务器
|
||||||
grpcServer *grpc.Server // GRPC模式下的服务器
|
grpcServer *grpc.Server // GRPC模式下的服务器
|
||||||
|
@ -80,14 +79,10 @@ type Server struct {
|
||||||
messagePool *concurrent.Pool[*Message] // 消息池
|
messagePool *concurrent.Pool[*Message] // 消息池
|
||||||
ctx context.Context // 上下文
|
ctx context.Context // 上下文
|
||||||
cancel context.CancelFunc // 停止上下文
|
cancel context.CancelFunc // 停止上下文
|
||||||
systemDispatcher *dispatcher // 系统消息分发器
|
|
||||||
systemSignal chan os.Signal // 系统信号
|
systemSignal chan os.Signal // 系统信号
|
||||||
closeChannel chan struct{} // 关闭信号
|
closeChannel chan struct{} // 关闭信号
|
||||||
multipleRuntimeErrorChan chan error // 多服务器模式下的运行时错误
|
multipleRuntimeErrorChan chan error // 多服务器模式下的运行时错误
|
||||||
dispatchers map[string]*dispatcher // 消息分发器集合
|
|
||||||
dispatcherMember map[string]map[string]*Conn // 消息分发器包含的连接
|
|
||||||
currDispatcher map[string]*dispatcher // 当前连接所处消息分发器
|
|
||||||
dispatcherLock sync.RWMutex // 消息分发器锁
|
|
||||||
messageCounter atomic.Int64 // 消息计数器
|
messageCounter atomic.Int64 // 消息计数器
|
||||||
addr string // 侦听地址
|
addr string // 侦听地址
|
||||||
network Network // 网络类型
|
network Network // 网络类型
|
||||||
|
@ -221,13 +216,6 @@ func (srv *Server) shutdown(err error) {
|
||||||
srv.ants.Release()
|
srv.ants.Release()
|
||||||
srv.ants = nil
|
srv.ants = nil
|
||||||
}
|
}
|
||||||
srv.dispatcherLock.Lock()
|
|
||||||
for s, d := range srv.dispatchers {
|
|
||||||
srv.OnShuntChannelClosedEvent(d.name)
|
|
||||||
d.close()
|
|
||||||
delete(srv.dispatchers, s)
|
|
||||||
}
|
|
||||||
srv.dispatcherLock.Unlock()
|
|
||||||
if srv.grpcServer != nil {
|
if srv.grpcServer != nil {
|
||||||
srv.grpcServer.GracefulStop()
|
srv.grpcServer.GracefulStop()
|
||||||
}
|
}
|
||||||
|
@ -300,107 +288,27 @@ func (srv *Server) GetMessageCount() int64 {
|
||||||
|
|
||||||
// UseShunt 切换连接所使用的消息分流渠道,当分流渠道 name 不存在时将会创建一个新的分流渠道,否则将会加入已存在的分流渠道
|
// UseShunt 切换连接所使用的消息分流渠道,当分流渠道 name 不存在时将会创建一个新的分流渠道,否则将会加入已存在的分流渠道
|
||||||
// - 默认情况下,所有连接都使用系统通道进行消息分发,当指定消息分流渠道时,将会使用指定的消息分流渠道进行消息分发
|
// - 默认情况下,所有连接都使用系统通道进行消息分发,当指定消息分流渠道时,将会使用指定的消息分流渠道进行消息分发
|
||||||
// - 在使用 WithDisableAutomaticReleaseShunt 创建服务器后,必须始终在连接不再使用后主动通过 ReleaseShunt 释放消息分流渠道,否则将造成内存泄漏
|
// - 分流渠道会在连接断开时标记为驱逐状态,当分流渠道中的所有消息处理完毕且没有新连接使用时,将会被清除
|
||||||
|
//
|
||||||
|
// 一些有趣的情况:
|
||||||
|
// - 当连接发送异步消息时,消息会被分为两部分,分别是异步部分和回调部分。异步部分会在当前的分流渠道中处理,而回调部分则是根据回调时所在的分流渠道进行处理
|
||||||
func (srv *Server) UseShunt(conn *Conn, name string) {
|
func (srv *Server) UseShunt(conn *Conn, name string) {
|
||||||
srv.dispatcherLock.Lock()
|
srv.dispatcherMgr.BindProducer(conn.GetID(), name)
|
||||||
defer srv.dispatcherLock.Unlock()
|
|
||||||
d, exist := srv.dispatchers[name]
|
|
||||||
if !exist {
|
|
||||||
d = generateDispatcher(name, srv.dispatchMessage)
|
|
||||||
srv.OnShuntChannelCreatedEvent(d.name)
|
|
||||||
go d.start()
|
|
||||||
srv.dispatchers[name] = d
|
|
||||||
}
|
|
||||||
|
|
||||||
curr, exist := srv.currDispatcher[conn.GetID()]
|
|
||||||
if exist {
|
|
||||||
if curr.name == name {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
delete(srv.dispatcherMember[curr.name], conn.GetID())
|
|
||||||
if curr.name != serverSystemDispatcher && len(srv.dispatcherMember[curr.name]) == 0 {
|
|
||||||
delete(srv.dispatchers, curr.name)
|
|
||||||
srv.OnShuntChannelClosedEvent(d.name)
|
|
||||||
curr.close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
srv.currDispatcher[conn.GetID()] = d
|
|
||||||
|
|
||||||
member, exist := srv.dispatcherMember[name]
|
|
||||||
if !exist {
|
|
||||||
member = map[string]*Conn{}
|
|
||||||
srv.dispatcherMember[name] = member
|
|
||||||
}
|
|
||||||
|
|
||||||
member[conn.GetID()] = conn
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// HasShunt 检查特定消息分流渠道是否存在
|
// HasShunt 检查特定消息分流渠道是否存在
|
||||||
func (srv *Server) HasShunt(name string) bool {
|
func (srv *Server) HasShunt(name string) bool {
|
||||||
srv.dispatcherLock.RLock()
|
return srv.dispatcherMgr.HasDispatcher(name)
|
||||||
defer srv.dispatcherLock.RUnlock()
|
|
||||||
_, exist := srv.dispatchers[name]
|
|
||||||
return exist
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetConnCurrShunt 获取连接当前所使用的消息分流渠道
|
// GetConnCurrShunt 获取连接当前所使用的消息分流渠道
|
||||||
func (srv *Server) GetConnCurrShunt(conn *Conn) string {
|
func (srv *Server) GetConnCurrShunt(conn *Conn) string {
|
||||||
srv.dispatcherLock.RLock()
|
return srv.dispatcherMgr.GetDispatcher(conn.GetID()).Name()
|
||||||
defer srv.dispatcherLock.RUnlock()
|
|
||||||
d, exist := srv.currDispatcher[conn.GetID()]
|
|
||||||
if exist {
|
|
||||||
return d.name
|
|
||||||
}
|
|
||||||
return serverSystemDispatcher
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetShuntNum 获取消息分流渠道数量
|
// GetShuntNum 获取消息分流渠道数量
|
||||||
func (srv *Server) GetShuntNum() int {
|
func (srv *Server) GetShuntNum() int {
|
||||||
srv.dispatcherLock.RLock()
|
return srv.dispatcherMgr.GetDispatcherNum()
|
||||||
defer srv.dispatcherLock.RUnlock()
|
|
||||||
return len(srv.dispatchers)
|
|
||||||
}
|
|
||||||
|
|
||||||
// getConnDispatcher 获取连接所使用的消息分发器
|
|
||||||
func (srv *Server) getConnDispatcher(conn *Conn) *dispatcher {
|
|
||||||
if conn == nil {
|
|
||||||
return srv.systemDispatcher
|
|
||||||
}
|
|
||||||
srv.dispatcherLock.RLock()
|
|
||||||
defer srv.dispatcherLock.RUnlock()
|
|
||||||
d, exist := srv.currDispatcher[conn.GetID()]
|
|
||||||
if exist {
|
|
||||||
return d
|
|
||||||
}
|
|
||||||
return srv.systemDispatcher
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReleaseShunt 释放分流渠道中的连接,当分流渠道中不再存在连接时将会自动释放分流渠道
|
|
||||||
// - 在未使用 WithDisableAutomaticReleaseShunt 选项时,当连接关闭时将会自动释放分流渠道中连接的资源占用
|
|
||||||
// - 若执行过程中连接正在使用,将会切换至系统通道
|
|
||||||
func (srv *Server) ReleaseShunt(conn *Conn) {
|
|
||||||
srv.releaseDispatcher(conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
// releaseDispatcher 关闭消息分发器
|
|
||||||
func (srv *Server) releaseDispatcher(conn *Conn) {
|
|
||||||
if conn == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
cid := conn.GetID()
|
|
||||||
srv.dispatcherLock.Lock()
|
|
||||||
defer srv.dispatcherLock.Unlock()
|
|
||||||
d, exist := srv.currDispatcher[cid]
|
|
||||||
if exist && d.name != serverSystemDispatcher {
|
|
||||||
delete(srv.dispatcherMember[d.name], cid)
|
|
||||||
if len(srv.dispatcherMember[d.name]) == 0 {
|
|
||||||
srv.OnShuntChannelClosedEvent(d.name)
|
|
||||||
d.close()
|
|
||||||
delete(srv.dispatchers, d.name)
|
|
||||||
}
|
|
||||||
delete(srv.currDispatcher, cid)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// pushMessage 向服务器中写入特定类型的消息,需严格遵守消息属性要求
|
// pushMessage 向服务器中写入特定类型的消息,需严格遵守消息属性要求
|
||||||
|
@ -409,25 +317,29 @@ func (srv *Server) pushMessage(message *Message) {
|
||||||
srv.messagePool.Release(message)
|
srv.messagePool.Release(message)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var dispatcher *dispatcher
|
var d *dispatcher.Dispatcher[string, *Message]
|
||||||
switch message.t {
|
switch message.t {
|
||||||
case MessageTypePacket,
|
case MessageTypePacket,
|
||||||
MessageTypeShuntTicker, MessageTypeShuntAsync, MessageTypeShuntAsyncCallback,
|
MessageTypeShuntTicker, MessageTypeShuntAsync, MessageTypeShuntAsyncCallback,
|
||||||
MessageTypeUniqueShuntAsync, MessageTypeUniqueShuntAsyncCallback,
|
MessageTypeUniqueShuntAsync, MessageTypeUniqueShuntAsyncCallback,
|
||||||
MessageTypeShunt:
|
MessageTypeShunt:
|
||||||
dispatcher = srv.getConnDispatcher(message.conn)
|
d = srv.dispatcherMgr.GetDispatcher(message.conn.GetID())
|
||||||
case MessageTypeSystem, MessageTypeAsync, MessageTypeUniqueAsync, MessageTypeAsyncCallback, MessageTypeUniqueAsyncCallback, MessageTypeTicker:
|
case MessageTypeSystem, MessageTypeAsync, MessageTypeUniqueAsync, MessageTypeAsyncCallback, MessageTypeUniqueAsyncCallback, MessageTypeTicker:
|
||||||
dispatcher = srv.systemDispatcher
|
d = srv.dispatcherMgr.GetSystemDispatcher()
|
||||||
}
|
}
|
||||||
if dispatcher == nil {
|
if d == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if (message.t == MessageTypeUniqueShuntAsync || message.t == MessageTypeUniqueAsync) && dispatcher.unique(message.name) {
|
if (message.t == MessageTypeUniqueShuntAsync || message.t == MessageTypeUniqueAsync) && d.Unique(message.name) {
|
||||||
srv.messagePool.Release(message)
|
srv.messagePool.Release(message)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
switch message.t {
|
||||||
|
case MessageTypeShuntAsync, MessageTypeUniqueShuntAsync:
|
||||||
|
d.IncrCount(message.conn.GetID(), 1)
|
||||||
|
}
|
||||||
srv.hitMessageStatistics()
|
srv.hitMessageStatistics()
|
||||||
dispatcher.put(message)
|
d.Put(message)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) low(message *Message, present time.Time, expect time.Duration, messageReplace ...string) {
|
func (srv *Server) low(message *Message, present time.Time, expect time.Duration, messageReplace ...string) {
|
||||||
|
@ -456,7 +368,7 @@ func (srv *Server) low(message *Message, present time.Time, expect time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
// dispatchMessage 消息分发
|
// dispatchMessage 消息分发
|
||||||
func (srv *Server) dispatchMessage(dispatcherIns *dispatcher, msg *Message) {
|
func (srv *Server) dispatchMessage(dispatcherIns *dispatcher.Dispatcher[string, *Message], msg *Message) {
|
||||||
var (
|
var (
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
|
@ -476,7 +388,7 @@ func (srv *Server) dispatchMessage(dispatcherIns *dispatcher, msg *Message) {
|
||||||
|
|
||||||
present := time.Now()
|
present := time.Now()
|
||||||
if msg.t != MessageTypeAsync && msg.t != MessageTypeUniqueAsync && msg.t != MessageTypeShuntAsync && msg.t != MessageTypeUniqueShuntAsync {
|
if msg.t != MessageTypeAsync && msg.t != MessageTypeUniqueAsync && msg.t != MessageTypeShuntAsync && msg.t != MessageTypeUniqueShuntAsync {
|
||||||
defer func(cancel context.CancelFunc, srv *Server, dispatcherIns *dispatcher, msg *Message, present time.Time) {
|
defer func(cancel context.CancelFunc, srv *Server, dispatcherIns *dispatcher.Dispatcher[string, *Message], msg *Message, present time.Time) {
|
||||||
super.Handle(cancel)
|
super.Handle(cancel)
|
||||||
if err := super.RecoverTransform(recover()); err != nil {
|
if err := super.RecoverTransform(recover()); err != nil {
|
||||||
stack := string(debug.Stack())
|
stack := string(debug.Stack())
|
||||||
|
@ -485,7 +397,7 @@ func (srv *Server) dispatchMessage(dispatcherIns *dispatcher, msg *Message) {
|
||||||
srv.OnMessageErrorEvent(msg, err)
|
srv.OnMessageErrorEvent(msg, err)
|
||||||
}
|
}
|
||||||
if msg.t == MessageTypeUniqueAsyncCallback || msg.t == MessageTypeUniqueShuntAsyncCallback {
|
if msg.t == MessageTypeUniqueAsyncCallback || msg.t == MessageTypeUniqueShuntAsyncCallback {
|
||||||
dispatcherIns.antiUnique(msg.name)
|
dispatcherIns.AntiUnique(msg.name)
|
||||||
}
|
}
|
||||||
|
|
||||||
srv.low(msg, present, time.Millisecond*100)
|
srv.low(msg, present, time.Millisecond*100)
|
||||||
|
@ -512,10 +424,14 @@ func (srv *Server) dispatchMessage(dispatcherIns *dispatcher, msg *Message) {
|
||||||
msg.ordinaryHandler()
|
msg.ordinaryHandler()
|
||||||
case MessageTypeAsync, MessageTypeShuntAsync, MessageTypeUniqueAsync, MessageTypeUniqueShuntAsync:
|
case MessageTypeAsync, MessageTypeShuntAsync, MessageTypeUniqueAsync, MessageTypeUniqueShuntAsync:
|
||||||
if err := srv.ants.Submit(func() {
|
if err := srv.ants.Submit(func() {
|
||||||
defer func(cancel context.CancelFunc, srv *Server, dispatcherIns *dispatcher, msg *Message, present time.Time) {
|
defer func(cancel context.CancelFunc, srv *Server, dispatcherIns *dispatcher.Dispatcher[string, *Message], msg *Message, present time.Time) {
|
||||||
|
switch msg.t {
|
||||||
|
case MessageTypeShuntAsync, MessageTypeUniqueShuntAsync:
|
||||||
|
dispatcherIns.IncrCount(msg.conn.GetID(), -1)
|
||||||
|
}
|
||||||
if err := super.RecoverTransform(recover()); err != nil {
|
if err := super.RecoverTransform(recover()); err != nil {
|
||||||
if msg.t == MessageTypeUniqueAsync || msg.t == MessageTypeUniqueShuntAsync {
|
if msg.t == MessageTypeUniqueAsync || msg.t == MessageTypeUniqueShuntAsync {
|
||||||
dispatcherIns.antiUnique(msg.name)
|
dispatcherIns.AntiUnique(msg.name)
|
||||||
}
|
}
|
||||||
stack := string(debug.Stack())
|
stack := string(debug.Stack())
|
||||||
log.Error("Server", log.String("MessageType", messageNames[msg.t]), log.Any("error", err), log.String("stack", stack))
|
log.Error("Server", log.String("MessageType", messageNames[msg.t]), log.Any("error", err), log.String("stack", stack))
|
||||||
|
@ -550,7 +466,7 @@ func (srv *Server) dispatchMessage(dispatcherIns *dispatcher, msg *Message) {
|
||||||
srv.PushShuntAsyncCallbackMessage(msg.conn, err, msg.errHandler)
|
srv.PushShuntAsyncCallbackMessage(msg.conn, err, msg.errHandler)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
dispatcherIns.antiUnique(msg.name)
|
dispatcherIns.AntiUnique(msg.name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("Server", log.String("MessageType", messageNames[msg.t]), log.Any("error", err), log.String("stack", string(debug.Stack())))
|
log.Error("Server", log.String("MessageType", messageNames[msg.t]), log.Any("error", err), log.String("stack", string(debug.Stack())))
|
||||||
}
|
}
|
||||||
|
@ -769,7 +685,8 @@ func onMessageSystemInit(srv *Server) {
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
srv.startMessageStatistics()
|
srv.startMessageStatistics()
|
||||||
srv.systemDispatcher = generateDispatcher(serverSystemDispatcher, srv.dispatchMessage)
|
srv.dispatcherMgr = dispatcher.NewManager[string, *Message](srv.dispatcherBufferSize, srv.dispatchMessage).
|
||||||
go srv.systemDispatcher.start()
|
SetDispatcherCreatedHandler(srv.OnShuntChannelCreatedEvent).
|
||||||
|
SetDispatcherClosedHandler(srv.OnShuntChannelClosedEvent)
|
||||||
srv.OnMessageReadyEvent()
|
srv.OnMessageReadyEvent()
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,7 +18,7 @@ func TestNew(t *testing.T) {
|
||||||
fmt.Println("启动完成")
|
fmt.Println("启动完成")
|
||||||
})
|
})
|
||||||
srv.RegConnectionClosedEvent(func(srv *server.Server, conn *server.Conn, err any) {
|
srv.RegConnectionClosedEvent(func(srv *server.Server, conn *server.Conn, err any) {
|
||||||
fmt.Println("关闭", conn.GetID(), err, "Count", srv.GetOnlineCount())
|
fmt.Println("关闭", conn.GetID(), err, "IncrCount", srv.GetOnlineCount())
|
||||||
})
|
})
|
||||||
|
|
||||||
srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet []byte) {
|
srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet []byte) {
|
||||||
|
@ -38,7 +38,7 @@ func TestNew2(t *testing.T) {
|
||||||
fmt.Println("启动完成")
|
fmt.Println("启动完成")
|
||||||
})
|
})
|
||||||
srv.RegConnectionClosedEvent(func(srv *server.Server, conn *server.Conn, err any) {
|
srv.RegConnectionClosedEvent(func(srv *server.Server, conn *server.Conn, err any) {
|
||||||
fmt.Println("关闭", conn.GetID(), err, "Count", srv.GetOnlineCount())
|
fmt.Println("关闭", conn.GetID(), err, "IncrCount", srv.GetOnlineCount())
|
||||||
})
|
})
|
||||||
|
|
||||||
srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet []byte) {
|
srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet []byte) {
|
||||||
|
|
|
@ -1,15 +1,21 @@
|
||||||
package buffer
|
package buffer
|
||||||
|
|
||||||
// NewRing 创建一个环形缓冲区
|
// NewRing 创建一个并发不安全的环形缓冲区
|
||||||
func NewRing[T any](initSize int) *Ring[T] {
|
// - initSize: 初始容量
|
||||||
if initSize <= 1 {
|
//
|
||||||
panic("initial size must be great than one")
|
// 当初始容量小于 2 或未设置时,将会使用默认容量 2
|
||||||
|
func NewRing[T any](initSize ...int) *Ring[T] {
|
||||||
|
if len(initSize) == 0 {
|
||||||
|
initSize = append(initSize, 2)
|
||||||
|
}
|
||||||
|
if initSize[0] < 2 {
|
||||||
|
initSize[0] = 2
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Ring[T]{
|
return &Ring[T]{
|
||||||
buf: make([]T, initSize),
|
buf: make([]T, initSize[0]),
|
||||||
initSize: initSize,
|
initSize: initSize[0],
|
||||||
size: initSize,
|
size: initSize[0],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -23,91 +29,120 @@ type Ring[T any] struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read 读取数据
|
// Read 读取数据
|
||||||
func (slf *Ring[T]) Read() (T, error) {
|
func (b *Ring[T]) Read() (T, error) {
|
||||||
var t T
|
var t T
|
||||||
if slf.r == slf.w {
|
if b.r == b.w {
|
||||||
return t, ErrBufferIsEmpty
|
return t, ErrBufferIsEmpty
|
||||||
}
|
}
|
||||||
|
|
||||||
v := slf.buf[slf.r]
|
v := b.buf[b.r]
|
||||||
slf.r++
|
b.r++
|
||||||
if slf.r == slf.size {
|
if b.r == b.size {
|
||||||
slf.r = 0
|
b.r = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
return v, nil
|
return v, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ReadAll 读取所有数据
|
||||||
|
func (b *Ring[T]) ReadAll() []T {
|
||||||
|
if b.r == b.w {
|
||||||
|
return nil // 没有数据时返回空切片
|
||||||
|
}
|
||||||
|
|
||||||
|
var length int
|
||||||
|
var data []T
|
||||||
|
|
||||||
|
if b.w > b.r {
|
||||||
|
length = b.w - b.r
|
||||||
|
} else {
|
||||||
|
length = len(b.buf) - b.r + b.w
|
||||||
|
}
|
||||||
|
data = make([]T, length) // 预分配空间
|
||||||
|
|
||||||
|
if b.w > b.r {
|
||||||
|
copy(data, b.buf[b.r:b.w])
|
||||||
|
} else {
|
||||||
|
copied := copy(data, b.buf[b.r:])
|
||||||
|
copy(data[copied:], b.buf[:b.w])
|
||||||
|
}
|
||||||
|
|
||||||
|
b.r = 0
|
||||||
|
b.w = 0
|
||||||
|
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
// Peek 查看数据
|
// Peek 查看数据
|
||||||
func (slf *Ring[T]) Peek() (t T, err error) {
|
func (b *Ring[T]) Peek() (t T, err error) {
|
||||||
if slf.r == slf.w {
|
if b.r == b.w {
|
||||||
return t, ErrBufferIsEmpty
|
return t, ErrBufferIsEmpty
|
||||||
}
|
}
|
||||||
|
|
||||||
return slf.buf[slf.r], nil
|
return b.buf[b.r], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write 写入数据
|
// Write 写入数据
|
||||||
func (slf *Ring[T]) Write(v T) {
|
func (b *Ring[T]) Write(v T) {
|
||||||
slf.buf[slf.w] = v
|
b.buf[b.w] = v
|
||||||
slf.w++
|
b.w++
|
||||||
|
|
||||||
if slf.w == slf.size {
|
if b.w == b.size {
|
||||||
slf.w = 0
|
b.w = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
if slf.w == slf.r {
|
if b.w == b.r {
|
||||||
slf.grow()
|
b.grow()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// grow 扩容
|
// grow 扩容
|
||||||
func (slf *Ring[T]) grow() {
|
func (b *Ring[T]) grow() {
|
||||||
var size int
|
var size int
|
||||||
if slf.size < 1024 {
|
if b.size < 1024 {
|
||||||
size = slf.size * 2
|
size = b.size * 2
|
||||||
} else {
|
} else {
|
||||||
size = slf.size + slf.size/4
|
size = b.size + b.size/4
|
||||||
}
|
}
|
||||||
|
|
||||||
buf := make([]T, size)
|
buf := make([]T, size)
|
||||||
|
|
||||||
copy(buf[0:], slf.buf[slf.r:])
|
copy(buf[0:], b.buf[b.r:])
|
||||||
copy(buf[slf.size-slf.r:], slf.buf[0:slf.r])
|
copy(buf[b.size-b.r:], b.buf[0:b.r])
|
||||||
|
|
||||||
slf.r = 0
|
b.r = 0
|
||||||
slf.w = slf.size
|
b.w = b.size
|
||||||
slf.size = size
|
b.size = size
|
||||||
slf.buf = buf
|
b.buf = buf
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsEmpty 是否为空
|
// IsEmpty 是否为空
|
||||||
func (slf *Ring[T]) IsEmpty() bool {
|
func (b *Ring[T]) IsEmpty() bool {
|
||||||
return slf.r == slf.w
|
return b.r == b.w
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cap 返回缓冲区容量
|
// Cap 返回缓冲区容量
|
||||||
func (slf *Ring[T]) Cap() int {
|
func (b *Ring[T]) Cap() int {
|
||||||
return slf.size
|
return b.size
|
||||||
}
|
}
|
||||||
|
|
||||||
// Len 返回缓冲区长度
|
// Len 返回缓冲区长度
|
||||||
func (slf *Ring[T]) Len() int {
|
func (b *Ring[T]) Len() int {
|
||||||
if slf.r == slf.w {
|
if b.r == b.w {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
if slf.w > slf.r {
|
if b.w > b.r {
|
||||||
return slf.w - slf.r
|
return b.w - b.r
|
||||||
}
|
}
|
||||||
|
|
||||||
return slf.size - slf.r + slf.w
|
return b.size - b.r + b.w
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset 重置缓冲区
|
// Reset 重置缓冲区
|
||||||
func (slf *Ring[T]) Reset() {
|
func (b *Ring[T]) Reset() {
|
||||||
slf.r = 0
|
b.r = 0
|
||||||
slf.w = 0
|
b.w = 0
|
||||||
slf.size = slf.initSize
|
b.size = b.initSize
|
||||||
slf.buf = make([]T, slf.initSize)
|
b.buf = make([]T, b.initSize)
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,25 @@
|
||||||
|
package buffer_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/kercylan98/minotaur/utils/buffer"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkRingWrite(b *testing.B) {
|
||||||
|
ring := buffer.NewRing[int](1024)
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
ring.Write(i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRingRead(b *testing.B) {
|
||||||
|
ring := buffer.NewRing[int](1024)
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
ring.Write(i)
|
||||||
|
}
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _ = ring.Read()
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,14 @@
|
||||||
|
package buffer_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/kercylan98/minotaur/utils/buffer"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewRing(t *testing.T) {
|
||||||
|
ring := buffer.NewRing[int]()
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
ring.Write(i)
|
||||||
|
t.Log(ring.Read())
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,100 @@
|
||||||
|
package buffer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewRingUnbounded 创建一个并发安全的基于环形缓冲区实现的无界缓冲区
|
||||||
|
func NewRingUnbounded[T any](bufferSize int) *RingUnbounded[T] {
|
||||||
|
ru := &RingUnbounded[T]{
|
||||||
|
ring: NewRing[T](1024),
|
||||||
|
rc: make(chan T, bufferSize),
|
||||||
|
closedSignal: make(chan struct{}),
|
||||||
|
}
|
||||||
|
ru.cond = sync.NewCond(&ru.rrm)
|
||||||
|
|
||||||
|
ru.process()
|
||||||
|
return ru
|
||||||
|
}
|
||||||
|
|
||||||
|
// RingUnbounded 基于环形缓冲区实现的无界缓冲区
|
||||||
|
type RingUnbounded[T any] struct {
|
||||||
|
ring *Ring[T]
|
||||||
|
rrm sync.Mutex
|
||||||
|
cond *sync.Cond
|
||||||
|
rc chan T
|
||||||
|
closed bool
|
||||||
|
closedMutex sync.RWMutex
|
||||||
|
closedSignal chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write 写入数据
|
||||||
|
func (b *RingUnbounded[T]) Write(v T) {
|
||||||
|
b.closedMutex.RLock()
|
||||||
|
defer b.closedMutex.RUnlock()
|
||||||
|
if b.closed {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
b.rrm.Lock()
|
||||||
|
b.ring.Write(v)
|
||||||
|
b.cond.Signal()
|
||||||
|
b.rrm.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read 读取数据
|
||||||
|
func (b *RingUnbounded[T]) Read() <-chan T {
|
||||||
|
return b.rc
|
||||||
|
}
|
||||||
|
|
||||||
|
// Closed 判断缓冲区是否已关闭
|
||||||
|
func (b *RingUnbounded[T]) Closed() bool {
|
||||||
|
b.closedMutex.RLock()
|
||||||
|
defer b.closedMutex.RUnlock()
|
||||||
|
return b.closed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close 关闭缓冲区,关闭后将不再接收新数据,但是已有数据仍然可以读取
|
||||||
|
func (b *RingUnbounded[T]) Close() <-chan struct{} {
|
||||||
|
b.closedMutex.Lock()
|
||||||
|
defer b.closedMutex.Unlock()
|
||||||
|
if b.closed {
|
||||||
|
return b.closedSignal
|
||||||
|
}
|
||||||
|
b.closed = true
|
||||||
|
|
||||||
|
b.rrm.Lock()
|
||||||
|
b.cond.Signal()
|
||||||
|
b.rrm.Unlock()
|
||||||
|
return b.closedSignal
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *RingUnbounded[T]) process() {
|
||||||
|
go func(b *RingUnbounded[T]) {
|
||||||
|
for {
|
||||||
|
b.closedMutex.RLock()
|
||||||
|
b.rrm.Lock()
|
||||||
|
vs := b.ring.ReadAll()
|
||||||
|
if len(vs) == 0 && !b.closed {
|
||||||
|
b.closedMutex.RUnlock()
|
||||||
|
b.cond.Wait()
|
||||||
|
} else {
|
||||||
|
b.closedMutex.RUnlock()
|
||||||
|
}
|
||||||
|
b.rrm.Unlock()
|
||||||
|
|
||||||
|
b.closedMutex.RLock()
|
||||||
|
if b.closed && len(vs) == 0 {
|
||||||
|
close(b.rc)
|
||||||
|
close(b.closedSignal)
|
||||||
|
b.closedMutex.RUnlock()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range vs {
|
||||||
|
b.rc <- v
|
||||||
|
}
|
||||||
|
b.closedMutex.RUnlock()
|
||||||
|
}
|
||||||
|
}(b)
|
||||||
|
}
|
|
@ -0,0 +1,48 @@
|
||||||
|
package buffer_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/kercylan98/minotaur/utils/buffer"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkRingUnbounded_Write(b *testing.B) {
|
||||||
|
ring := buffer.NewRingUnbounded[int](1024 * 16)
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
ring.Write(i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRingUnbounded_Read(b *testing.B) {
|
||||||
|
ring := buffer.NewRingUnbounded[int](1024 * 16)
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
ring.Write(i)
|
||||||
|
}
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
<-ring.Read()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRingUnbounded_Write_Parallel(b *testing.B) {
|
||||||
|
ring := buffer.NewRingUnbounded[int](1024 * 16)
|
||||||
|
b.ResetTimer()
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
for pb.Next() {
|
||||||
|
ring.Write(1)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRingUnbounded_Read_Parallel(b *testing.B) {
|
||||||
|
ring := buffer.NewRingUnbounded[int](1024 * 16)
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
ring.Write(i)
|
||||||
|
}
|
||||||
|
b.ResetTimer()
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
for pb.Next() {
|
||||||
|
<-ring.Read()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
|
@ -0,0 +1,33 @@
|
||||||
|
package buffer_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/kercylan98/minotaur/utils/buffer"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRingUnbounded_Write2Read(t *testing.T) {
|
||||||
|
ring := buffer.NewRingUnbounded[int](1024 * 16)
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
ring.Write(i)
|
||||||
|
}
|
||||||
|
t.Log("write done")
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
t.Log(<-ring.Read())
|
||||||
|
}
|
||||||
|
t.Log("read done")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRingUnbounded_Close(t *testing.T) {
|
||||||
|
ring := buffer.NewRingUnbounded[int](1024 * 16)
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
ring.Write(i)
|
||||||
|
}
|
||||||
|
t.Log("write done")
|
||||||
|
ring.Close()
|
||||||
|
t.Log("close done")
|
||||||
|
for v := range ring.Read() {
|
||||||
|
ring.Write(v)
|
||||||
|
t.Log(v)
|
||||||
|
}
|
||||||
|
t.Log("read done")
|
||||||
|
}
|
|
@ -5,15 +5,46 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func BenchmarkUnboundedBuffer(b *testing.B) {
|
func BenchmarkUnbounded_Write(b *testing.B) {
|
||||||
ub := buffer.NewUnbounded[int]()
|
u := buffer.NewUnbounded[int]()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
u.Put(i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkUnbounded_Read(b *testing.B) {
|
||||||
|
u := buffer.NewUnbounded[int]()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
u.Put(i)
|
||||||
|
}
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
<-u.Get()
|
||||||
|
u.Load()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkUnbounded_Write_Parallel(b *testing.B) {
|
||||||
|
u := buffer.NewUnbounded[int]()
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
b.RunParallel(func(pb *testing.PB) {
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
for pb.Next() {
|
for pb.Next() {
|
||||||
ub.Put(1)
|
u.Put(1)
|
||||||
<-ub.Get()
|
}
|
||||||
ub.Load()
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkUnbounded_Read_Parallel(b *testing.B) {
|
||||||
|
u := buffer.NewUnbounded[int]()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
u.Put(i)
|
||||||
|
}
|
||||||
|
b.ResetTimer()
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
for pb.Next() {
|
||||||
|
<-u.Get()
|
||||||
|
u.Load()
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue