refactor: rework the server package's shunt channel design and fix several issues

- Implement the shunt channel's unbounded buffer with a ring buffer, fixing the issue where messages that had not finished processing were dropped once a shunt channel was closed;
- Remove the server.WithDisableAutomaticReleaseShunt option; a shunt channel now releases itself once its messages are processed and no connection uses it (see the sketch below);
kercylan98 2024-01-08 19:10:12 +08:00
parent 3402c83fd4
commit 3408c212d0
21 changed files with 863 additions and 325 deletions
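
The second bullet changes how shunt channels are used in practice: there is no longer any manual release call. Below is a minimal sketch of the new usage. Only UseShunt, RegConnectionReceivePacketEvent and the automatic release behaviour come from this commit; server.New, NetworkWebsocket and Run are assumed from earlier releases of the package.

package main

import "github.com/kercylan98/minotaur/server"

func main() {
	srv := server.New(server.NetworkWebsocket) // NetworkWebsocket is assumed API
	srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet []byte) {
		// Route all subsequent messages of this connection through "room-1".
		// When the last connection bound to "room-1" disconnects and the
		// channel drains, it is released automatically; no ReleaseShunt
		// call is required anymore.
		srv.UseShunt(conn, "room-1")
	})
	if err := srv.Run(":9999"); err != nil { // Run(addr) is assumed API
		panic(err)
	}
}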

View File

@ -366,7 +366,4 @@ func (slf *Conn) Close(err ...error) {
return
}
slf.server.OnConnectionClosedEvent(slf, nil)
if !slf.server.runtime.disableAutomaticReleaseShunt {
slf.server.releaseDispatcher(slf)
}
}

View File

@ -7,9 +7,8 @@ import (
)
const (
serverMultipleMark = "Minotaur Multiple Server"
serverMark = "Minotaur Server"
serverSystemDispatcher = "__system" // system message dispatcher
serverMultipleMark = "Minotaur Multiple Server"
serverMark = "Minotaur Server"
)
const (

View File

@ -1,83 +0,0 @@
package server
import (
"github.com/alphadose/haxmap"
"github.com/kercylan98/minotaur/utils/buffer"
"sync/atomic"
"time"
)
var dispatcherUnique = struct{}{}
// generateDispatcher creates a message dispatcher
func generateDispatcher(name string, handler func(dispatcher *dispatcher, message *Message)) *dispatcher {
d := &dispatcher{
name: name,
buf: buffer.NewUnbounded[*Message](),
handler: handler,
uniques: haxmap.New[string, struct{}](),
}
return d
}
// dispatcher is a message dispatcher
type dispatcher struct {
name string
buf *buffer.Unbounded[*Message]
uniques *haxmap.Map[string, struct{}]
handler func(dispatcher *dispatcher, message *Message)
closed uint32
msgCount int64
}
func (d *dispatcher) unique(name string) bool {
_, loaded := d.uniques.GetOrSet(name, dispatcherUnique)
return loaded
}
func (d *dispatcher) antiUnique(name string) {
d.uniques.Del(name)
}
func (d *dispatcher) start() {
defer d.buf.Close()
for {
select {
case message, ok := <-d.buf.Get():
if !ok {
return
}
d.buf.Load()
d.handler(d, message)
if atomic.AddInt64(&d.msgCount, -1) <= 0 && atomic.LoadUint32(&d.closed) == 1 {
return
}
}
}
}
func (d *dispatcher) put(message *Message) {
if atomic.CompareAndSwapUint32(&d.closed, 1, 1) {
return
}
atomic.AddInt64(&d.msgCount, 1)
d.buf.Put(message)
}
func (d *dispatcher) close() {
atomic.CompareAndSwapUint32(&d.closed, 0, 1)
go func() {
for {
if d.buf.IsClosed() {
return
}
if atomic.LoadInt64(&d.msgCount) <= 0 {
d.buf.Close()
return
}
time.Sleep(time.Second)
}
}()
}

View File

@ -216,6 +216,7 @@ func (slf *event) OnConnectionClosedEvent(conn *Conn, err any) {
value(slf.Server, conn, err)
return true
})
slf.Server.dispatcherMgr.UnBindProducer(conn.GetID())
}, log.String("Event", "OnConnectionClosedEvent"))
}

View File

@ -0,0 +1,153 @@
package dispatcher
import (
"github.com/alphadose/haxmap"
"github.com/kercylan98/minotaur/utils/buffer"
"sync"
"sync/atomic"
)
var unique = struct{}{}
// Handler is a message handler
type Handler[P Producer, M Message[P]] func(dispatcher *Dispatcher[P, M], message M)
// NewDispatcher creates a message dispatcher
func NewDispatcher[P Producer, M Message[P]](bufferSize int, name string, handler Handler[P, M]) *Dispatcher[P, M] {
d := &Dispatcher[P, M]{
name: name,
buf: buffer.NewRingUnbounded[M](bufferSize),
handler: handler,
uniques: haxmap.New[string, struct{}](),
pmc: make(map[P]int64),
pmcF: make(map[P]func(p P, dispatcher *Dispatcher[P, M])),
abort: make(chan struct{}),
}
return d
}
// Dispatcher is a message dispatcher
type Dispatcher[P Producer, M Message[P]] struct {
buf *buffer.RingUnbounded[M]
uniques *haxmap.Map[string, struct{}]
handler Handler[P, M]
expel bool
mc int64
pmc map[P]int64
pmcF map[P]func(p P, dispatcher *Dispatcher[P, M])
lock sync.RWMutex
name string
closedHandler atomic.Pointer[func(dispatcher *Dispatcher[P, M])]
abort chan struct{}
}
// SetProducerDoneHandler sets the callback invoked when all messages from a specific producer have been processed
func (d *Dispatcher[P, M]) SetProducerDoneHandler(p P, handler func(p P, dispatcher *Dispatcher[P, M])) *Dispatcher[P, M] {
d.lock.Lock()
if handler == nil {
delete(d.pmcF, p)
} else {
d.pmcF[p] = handler
}
d.lock.Unlock()
return d
}
// SetClosedHandler sets the callback invoked when the dispatcher is closed
func (d *Dispatcher[P, M]) SetClosedHandler(handler func(dispatcher *Dispatcher[P, M])) *Dispatcher[P, M] {
d.closedHandler.Store(&handler)
return d
}
// Name returns the dispatcher name
func (d *Dispatcher[P, M]) Name() string {
return d.name
}
// Unique registers a unique message key and reports whether it already existed
func (d *Dispatcher[P, M]) Unique(name string) bool {
_, loaded := d.uniques.GetOrSet(name, unique)
return loaded
}
// AntiUnique removes a unique message key
func (d *Dispatcher[P, M]) AntiUnique(name string) {
d.uniques.Del(name)
}
// Expel marks the dispatcher for eviction; once it holds no more messages, it closes automatically
func (d *Dispatcher[P, M]) Expel() {
d.lock.Lock()
d.expel = true
if d.mc <= 0 {
d.abort <- struct{}{}
}
d.lock.Unlock()
}
// UnExpel cancels the dispatcher's eviction plan
func (d *Dispatcher[P, M]) UnExpel() {
d.lock.Lock()
d.expel = false
d.lock.Unlock()
}
// IncrCount manually adjusts the message count of a specific producer, which is useful for waiting on asynchronous work to finish before closing the dispatcher
func (d *Dispatcher[P, M]) IncrCount(producer P, i int64) {
d.lock.Lock()
d.mc += i
d.pmc[producer] += i
if d.expel && d.mc <= 0 {
d.abort <- struct{}{}
}
d.lock.Unlock()
}
// Put places a message into the dispatcher
func (d *Dispatcher[P, M]) Put(message M) {
d.lock.Lock()
d.mc++
d.pmc[message.GetProducer()]++
d.lock.Unlock()
d.buf.Write(message)
}
// Start begins message dispatching in a background goroutine; once the dispatcher is expelled and holds no more messages, it closes automatically
func (d *Dispatcher[P, M]) Start() *Dispatcher[P, M] {
go func(d *Dispatcher[P, M]) {
process:
for {
select {
case <-d.abort:
d.buf.Close()
break process
case message := <-d.buf.Read():
d.handler(d, message)
d.lock.Lock()
d.mc--
p := message.GetProducer()
pmc := d.pmc[p] - 1
d.pmc[p] = pmc
if f := d.pmcF[p]; f != nil && pmc <= 0 {
go f(p, d)
}
if d.mc <= 0 && d.expel {
d.buf.Close()
d.lock.Unlock() // release the lock before leaving the loop, otherwise Put and Expel would deadlock
break process
}
d.lock.Unlock()
}
}
if handlerPtr := d.closedHandler.Load(); handlerPtr != nil { // Load returns nil when no handler was ever set
if closedHandler := *handlerPtr; closedHandler != nil {
closedHandler(d)
}
}
close(d.abort)
}(d)
return d
}
// Closed reports whether the dispatcher has been closed
func (d *Dispatcher[P, M]) Closed() bool {
return d.buf.Closed()
}
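
The IncrCount reservation is the mechanism the server uses further down (pushMessage/dispatchMessage) to keep a dispatcher alive while shunt-async work is still in flight. A hedged sketch of the same pattern in isolation; illustrative only, since the package is internal to the module, and the Msg type is hypothetical:

package dispatcher_test

import (
	"fmt"
	"time"

	"github.com/kercylan98/minotaur/server/internal/dispatcher"
)

type Msg struct{ producer string }

func (m *Msg) GetProducer() string { return m.producer }

func ExampleDispatcher_IncrCount() {
	d := dispatcher.NewDispatcher[string, *Msg](1024, "example",
		func(dp *dispatcher.Dispatcher[string, *Msg], m *Msg) {
			// Reserve one count before spawning async work so that Expel
			// cannot close the dispatcher while the work is in flight.
			dp.IncrCount(m.GetProducer(), 1)
			go func() {
				defer dp.IncrCount(m.GetProducer(), -1) // release when done
				time.Sleep(100 * time.Millisecond)
				fmt.Println("async part finished")
			}()
		}).
		SetClosedHandler(func(dp *dispatcher.Dispatcher[string, *Msg]) {
			fmt.Println("closed")
		}).
		Start()
	d.Put(&Msg{producer: "conn-1"})
	d.Expel() // takes effect only once the count drains back to zero
	time.Sleep(time.Second)
}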

View File

@ -0,0 +1,48 @@
package dispatcher_test
import (
"github.com/kercylan98/minotaur/server/internal/dispatcher"
"sync"
"testing"
"time"
)
type TestMessage struct {
producer string
v int
}
func (m *TestMessage) GetProducer() string {
return m.producer
}
func TestDispatcher_PutStartClose(t *testing.T) {
// After the writes complete, close the dispatcher and then start dispatching, ensuring no message is lost
w := new(sync.WaitGroup)
cw := new(sync.WaitGroup)
cw.Add(1)
d := dispatcher.NewDispatcher[string, *TestMessage](1024*16, "test", func(dispatcher *dispatcher.Dispatcher[string, *TestMessage], message *TestMessage) {
t.Log(message)
w.Done()
}).SetClosedHandler(func(dispatcher *dispatcher.Dispatcher[string, *TestMessage]) {
t.Log("closed")
cw.Done()
})
for i := 0; i < 100; i++ {
w.Add(1)
d.Put(&TestMessage{
producer: "test",
v: i,
})
}
d.Start()
d.Expel()
d.UnExpel()
w.Wait()
time.Sleep(time.Second)
d.Expel()
cw.Wait()
t.Log("done")
}

View File

@ -0,0 +1,153 @@
package dispatcher
import (
"sync"
)
const SystemName = "*system"
// NewManager creates a message dispatcher manager
func NewManager[P Producer, M Message[P]](bufferSize int, handler Handler[P, M]) *Manager[P, M] {
mgr := &Manager[P, M]{
handler: handler,
dispatchers: make(map[string]*Dispatcher[P, M]),
member: make(map[string]map[P]struct{}),
sys: NewDispatcher(bufferSize, SystemName, handler).Start(),
curr: make(map[P]*Dispatcher[P, M]),
size: bufferSize,
}
return mgr
}
// Manager is a message dispatcher manager
type Manager[P Producer, M Message[P]] struct {
handler Handler[P, M] // message handler
sys *Dispatcher[P, M] // system message dispatcher
dispatchers map[string]*Dispatcher[P, M] // all dispatchers currently at work
member map[string]map[P]struct{} // producers attached to each working dispatcher
curr map[P]*Dispatcher[P, M] // dispatcher currently used by each producer
lock sync.RWMutex // dispatcher lock
w sync.WaitGroup // dispatcher wait group
size int // dispatcher buffer size
closedHandler func(name string)
createdHandler func(name string)
}
// SetDispatcherClosedHandler sets the callback invoked when a dispatcher is closed
func (m *Manager[P, M]) SetDispatcherClosedHandler(handler func(name string)) *Manager[P, M] {
m.closedHandler = handler
return m
}
// SetDispatcherCreatedHandler sets the callback invoked when a dispatcher is created
func (m *Manager[P, M]) SetDispatcherCreatedHandler(handler func(name string)) *Manager[P, M] {
m.createdHandler = handler
return m
}
// HasDispatcher checks whether a dispatcher with the given name exists
func (m *Manager[P, M]) HasDispatcher(name string) bool {
m.lock.RLock()
defer m.lock.RUnlock()
_, exist := m.dispatchers[name]
return exist
}
// GetDispatcherNum returns the number of dispatchers currently at work
func (m *Manager[P, M]) GetDispatcherNum() int {
m.lock.RLock()
defer m.lock.RUnlock()
return len(m.dispatchers) + 1 // +1 for the system dispatcher
}
// GetSystemDispatcher returns the system message dispatcher
func (m *Manager[P, M]) GetSystemDispatcher() *Dispatcher[P, M] {
return m.sys
}
// GetDispatcher returns the dispatcher a producer is using; if the producer has no bound dispatcher, the system dispatcher is returned
func (m *Manager[P, M]) GetDispatcher(p P) *Dispatcher[P, M] {
m.lock.Lock()
defer m.lock.Unlock()
curr, exist := m.curr[p]
if exist {
return curr
}
return m.sys
}
// BindProducer binds a producer to a specific dispatcher; if the producer is already bound to one, it is unbound first
func (m *Manager[P, M]) BindProducer(p P, name string) {
if name == SystemName {
return
}
m.lock.Lock()
defer m.lock.Unlock()
member, exist := m.member[name]
if !exist {
member = make(map[P]struct{})
m.member[name] = member
}
if _, exist = member[p]; exist {
d := m.dispatchers[name]
d.SetProducerDoneHandler(p, nil)
d.UnExpel()
return
}
curr, exist := m.curr[p]
if exist {
delete(m.member[curr.name], p)
if len(m.member[curr.name]) == 0 {
curr.Expel()
}
}
dispatcher, exist := m.dispatchers[name]
if !exist {
dispatcher = NewDispatcher(m.size, name, m.handler).SetClosedHandler(func(dispatcher *Dispatcher[P, M]) {
// when the dispatcher closes, remove it from the manager
m.lock.Lock()
delete(m.dispatchers, dispatcher.name)
delete(m.member, dispatcher.name)
m.lock.Unlock()
if m.closedHandler != nil {
m.closedHandler(dispatcher.name)
}
}).Start()
m.dispatchers[name] = dispatcher
defer func(m *Manager[P, M], name string) {
if m.createdHandler != nil {
m.createdHandler(name)
}
}(m, dispatcher.Name())
}
m.curr[p] = dispatcher
member[p] = struct{}{}
}
// UnBindProducer unbinds a producer from the dispatcher it is using
func (m *Manager[P, M]) UnBindProducer(p P) {
m.lock.Lock()
defer m.lock.Unlock()
curr, exist := m.curr[p]
if !exist {
return
}
curr.SetProducerDoneHandler(p, func(p P, dispatcher *Dispatcher[P, M]) {
m.lock.Lock()
defer m.lock.Unlock()
delete(m.member[dispatcher.name], p)
delete(m.curr, p)
if len(m.member[dispatcher.name]) == 0 {
dispatcher.Expel()
}
})
}
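
The bind/unbind flow above mirrors what the server does in UseShunt and in the connection-closed event. A short hedged sketch, extending the hypothetical test file above (it reuses the Msg type and imports from the previous sketch):

func ExampleManager_BindProducer() {
	mgr := dispatcher.NewManager[string, *Msg](1024,
		func(dp *dispatcher.Dispatcher[string, *Msg], m *Msg) {
			fmt.Println(dp.Name(), "handled a message from", m.GetProducer())
		})
	// An unbound producer falls through to the system dispatcher.
	mgr.GetDispatcher("conn-1").Put(&Msg{producer: "conn-1"})
	// After binding, its messages run on the named dispatcher instead.
	mgr.BindProducer("conn-1", "room-1")
	mgr.GetDispatcher("conn-1").Put(&Msg{producer: "conn-1"})
	// Unbinding registers a done handler: once this producer's outstanding
	// messages are processed it is detached, and the dispatcher is expelled
	// when no other producer uses it.
	mgr.UnBindProducer("conn-1")
	time.Sleep(time.Second)
}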

View File

@ -0,0 +1,46 @@
package dispatcher_test
import (
"github.com/kercylan98/minotaur/server/internal/dispatcher"
"github.com/kercylan98/minotaur/utils/times"
"testing"
"time"
)
func TestManager(t *testing.T) {
var mgr *dispatcher.Manager[string, *TestMessage]
var onHandler = func(dispatcher *dispatcher.Dispatcher[string, *TestMessage], message *TestMessage) {
t.Log(dispatcher.Name(), message, mgr.GetDispatcherNum())
switch message.v {
case 4:
mgr.UnBindProducer("test")
t.Log("UnBindProducer")
case 6:
mgr.BindProducer(message.GetProducer(), "test-dispatcher")
t.Log("BindProducer")
case 9:
dispatcher.Put(&TestMessage{
producer: "test",
v: 10,
})
case 10:
mgr.UnBindProducer("test")
t.Log("UnBindProducer", mgr.GetDispatcherNum())
}
}
mgr = dispatcher.NewManager[string, *TestMessage](1024*16, onHandler)
mgr.BindProducer("test", "test-dispatcher")
for i := 0; i < 10; i++ {
d := mgr.GetDispatcher("test").SetClosedHandler(func(dispatcher *dispatcher.Dispatcher[string, *TestMessage]) {
t.Log("closed")
})
d.Put(&TestMessage{
producer: "test",
v: i,
})
}
time.Sleep(times.Day)
}

View File

@ -0,0 +1,6 @@
package dispatcher
type Message[P comparable] interface {
// GetProducer returns the message producer
GetProducer() P
}

View File

@ -0,0 +1,5 @@
package dispatcher
type Producer interface {
comparable
}

View File

@ -75,6 +75,7 @@ func HasMessageType(mt MessageType) bool {
// Message is a server message
type Message struct {
producer string
conn *Conn
ordinaryHandler func()
exceptionHandler func() error
@ -86,6 +87,10 @@ type Message struct {
marks []log.Field
}
func (slf *Message) GetProducer() string {
return slf.producer
}
// reset resets the message struct
func (slf *Message) reset() {
slf.conn = nil
@ -126,78 +131,91 @@ func (slf MessageType) String() string {
// castToPacketMessage converts the message into a packet message
func (slf *Message) castToPacketMessage(conn *Conn, packet []byte, mark ...log.Field) *Message {
slf.producer = conn.GetID()
slf.t, slf.conn, slf.packet, slf.marks = MessageTypePacket, conn, packet, mark
return slf
}
// castToTickerMessage converts the message into a ticker message
func (slf *Message) castToTickerMessage(name string, caller func(), mark ...log.Field) *Message {
slf.producer = "sys"
slf.t, slf.name, slf.ordinaryHandler, slf.marks = MessageTypeTicker, name, caller, mark
return slf
}
// castToShuntTickerMessage converts the message into a shunt ticker message
func (slf *Message) castToShuntTickerMessage(conn *Conn, name string, caller func(), mark ...log.Field) *Message {
slf.producer = conn.GetID()
slf.t, slf.conn, slf.name, slf.ordinaryHandler, slf.marks = MessageTypeShuntTicker, conn, name, caller, mark
return slf
}
// castToAsyncMessage converts the message into an async message
func (slf *Message) castToAsyncMessage(caller func() error, callback func(err error), mark ...log.Field) *Message {
slf.producer = "sys"
slf.t, slf.exceptionHandler, slf.errHandler, slf.marks = MessageTypeAsync, caller, callback, mark
return slf
}
// castToAsyncCallbackMessage converts the message into an async callback message
func (slf *Message) castToAsyncCallbackMessage(err error, caller func(err error), mark ...log.Field) *Message {
slf.producer = "sys"
slf.t, slf.err, slf.errHandler, slf.marks = MessageTypeAsyncCallback, err, caller, mark
return slf
}
// castToShuntAsyncMessage converts the message into a shunt async message
func (slf *Message) castToShuntAsyncMessage(conn *Conn, caller func() error, callback func(err error), mark ...log.Field) *Message {
slf.producer = conn.GetID()
slf.t, slf.conn, slf.exceptionHandler, slf.errHandler, slf.marks = MessageTypeShuntAsync, conn, caller, callback, mark
return slf
}
// castToShuntAsyncCallbackMessage converts the message into a shunt async callback message
func (slf *Message) castToShuntAsyncCallbackMessage(conn *Conn, err error, caller func(err error), mark ...log.Field) *Message {
slf.producer = conn.GetID()
slf.t, slf.conn, slf.err, slf.errHandler, slf.marks = MessageTypeShuntAsyncCallback, conn, err, caller, mark
return slf
}
// castToUniqueAsyncMessage converts the message into a unique async message
func (slf *Message) castToUniqueAsyncMessage(unique string, caller func() error, callback func(err error), mark ...log.Field) *Message {
slf.producer = "sys"
slf.t, slf.name, slf.exceptionHandler, slf.errHandler, slf.marks = MessageTypeUniqueAsync, unique, caller, callback, mark
return slf
}
// castToUniqueAsyncCallbackMessage converts the message into a unique async callback message
func (slf *Message) castToUniqueAsyncCallbackMessage(unique string, err error, caller func(err error), mark ...log.Field) *Message {
slf.producer = "sys"
slf.t, slf.name, slf.err, slf.errHandler, slf.marks = MessageTypeUniqueAsyncCallback, unique, err, caller, mark
return slf
}
// castToUniqueShuntAsyncMessage converts the message into a unique shunt async message
func (slf *Message) castToUniqueShuntAsyncMessage(conn *Conn, unique string, caller func() error, callback func(err error), mark ...log.Field) *Message {
slf.producer = conn.GetID()
slf.t, slf.conn, slf.name, slf.exceptionHandler, slf.errHandler, slf.marks = MessageTypeUniqueShuntAsync, conn, unique, caller, callback, mark
return slf
}
// castToUniqueShuntAsyncCallbackMessage converts the message into a unique shunt async callback message
func (slf *Message) castToUniqueShuntAsyncCallbackMessage(conn *Conn, unique string, err error, caller func(err error), mark ...log.Field) *Message {
slf.producer = conn.GetID()
slf.t, slf.conn, slf.name, slf.err, slf.errHandler, slf.marks = MessageTypeUniqueShuntAsyncCallback, conn, unique, err, caller, mark
return slf
}
// castToSystemMessage converts the message into a system message
func (slf *Message) castToSystemMessage(caller func(), mark ...log.Field) *Message {
slf.producer = "sys"
slf.t, slf.ordinaryHandler, slf.marks = MessageTypeSystem, caller, mark
return slf
}
// castToShuntMessage converts the message into a shunt message
func (slf *Message) castToShuntMessage(conn *Conn, caller func(), mark ...log.Field) *Message {
slf.producer = conn.GetID()
slf.t, slf.conn, slf.ordinaryHandler, slf.marks = MessageTypeShunt, conn, caller, mark
return slf
}

View File

@ -32,26 +32,26 @@ type option struct {
}
type runtime struct {
deadlockDetect time.Duration // whether deadlock detection is enabled
supportMessageTypes map[int]bool // message types supported in websocket mode
certFile, keyFile string // TLS files
tickerPool *timer.Pool // ticker pool
ticker *timer.Ticker // ticker
tickerAutonomy bool // whether the ticker runs independently
connTickerSize int // connection ticker size
websocketReadDeadline time.Duration // websocket connection timeout
websocketCompression int // websocket compression level
websocketWriteCompression bool // websocket write compression
limitLife time.Duration // maximum lifetime limit
packetWarnSize int // packet size warning threshold
messageStatisticsDuration time.Duration // message statistics duration
messageStatisticsLimit int // message statistics count limit
messageStatistics []*atomic.Int64 // message statistics counters
messageStatisticsLock *sync.RWMutex // message statistics lock
connWriteBufferSize int // connection write buffer size
disableAutomaticReleaseShunt bool // whether automatic shunt channel release is disabled
websocketUpgrader *websocket.Upgrader // websocket upgrader
websocketConnInitializer func(writer http.ResponseWriter, request *http.Request, conn *websocket.Conn) error // websocket connection initializer
deadlockDetect time.Duration // whether deadlock detection is enabled
supportMessageTypes map[int]bool // message types supported in websocket mode
certFile, keyFile string // TLS files
tickerPool *timer.Pool // ticker pool
ticker *timer.Ticker // ticker
tickerAutonomy bool // whether the ticker runs independently
connTickerSize int // connection ticker size
websocketReadDeadline time.Duration // websocket connection timeout
websocketCompression int // websocket compression level
websocketWriteCompression bool // websocket write compression
limitLife time.Duration // maximum lifetime limit
packetWarnSize int // packet size warning threshold
messageStatisticsDuration time.Duration // message statistics duration
messageStatisticsLimit int // message statistics count limit
messageStatistics []*atomic.Int64 // message statistics counters
messageStatisticsLock *sync.RWMutex // message statistics lock
connWriteBufferSize int // connection write buffer size
websocketUpgrader *websocket.Upgrader // websocket upgrader
websocketConnInitializer func(writer http.ResponseWriter, request *http.Request, conn *websocket.Conn) error // websocket connection initializer
dispatcherBufferSize int // message dispatcher buffer size
}
// WithWebsocketConnInitializer creates the server with a websocket connection initializer; when the initializer returns an error, the server will not process any further logic for that connection
@ -77,14 +77,6 @@ func WithWebsocketUpgrade(upgrader *websocket.Upgrader) Option {
}
}
// WithDisableAutomaticReleaseShunt creates the server with automatic shunt channel release disabled
// - Disabled by default. With automatic release disabled, the server will not release a shunt channel when a connection disconnects; it must be released manually via ReleaseShunt
func WithDisableAutomaticReleaseShunt() Option {
return func(srv *Server) {
srv.runtime.disableAutomaticReleaseShunt = true
}
}
// WithConnWriteBufferSize creates the server with a specific connection write buffer size
// - The default is DefaultConnWriteBufferSize
// - A suitable buffer size can improve server performance at the cost of more memory
@ -100,14 +92,14 @@ func WithConnWriteBufferSize(size int) Option {
// WithDispatcherBufferSize creates the server with a specific message dispatcher buffer size (see the sketch after the function below)
// - The default is DefaultDispatcherBufferSize
// - A suitable buffer size can improve server performance at the cost of more memory
//func WithDispatcherBufferSize(size int) Option {
// return func(srv *Server) {
// if size <= 0 {
// return
// }
// srv.dispatcherBufferSize = size
// }
//}
func WithDispatcherBufferSize(size int) Option {
return func(srv *Server) {
if size <= 0 {
return
}
srv.dispatcherBufferSize = size
}
}
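
With the option re-enabled, the per-dispatcher buffer becomes tunable at construction time. A hedged usage sketch (NetworkWebsocket and Run are assumed, as in the sketch at the top of this commit):

package main

import "github.com/kercylan98/minotaur/server"

func main() {
	// Each dispatcher flushes its ring buffer through a channel of this size;
	// a larger value trades memory for fewer condition-variable wake-ups.
	srv := server.New(server.NetworkWebsocket,
		server.WithDispatcherBufferSize(1024*32),
	)
	if err := srv.Run(":9999"); err != nil {
		panic(err)
	}
}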
// WithMessageStatistics creates the server with message statistics enabled
// - Disabled by default. When both duration and limit are greater than 0, the server records the number of messages per duration and keeps at most limit entries

View File

@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/kercylan98/minotaur/server/internal/dispatcher"
"github.com/kercylan98/minotaur/server/internal/logger"
"github.com/kercylan98/minotaur/utils/concurrent"
"github.com/kercylan98/minotaur/utils/log"
@ -21,7 +22,6 @@ import (
"os"
"os/signal"
"runtime/debug"
"sync"
"sync/atomic"
"syscall"
"time"
@ -32,17 +32,15 @@ func New(network Network, options ...Option) *Server {
network.check()
server := &Server{
runtime: &runtime{
packetWarnSize: DefaultPacketWarnSize,
connWriteBufferSize: DefaultConnWriteBufferSize,
packetWarnSize: DefaultPacketWarnSize,
connWriteBufferSize: DefaultConnWriteBufferSize,
dispatcherBufferSize: DefaultDispatcherBufferSize,
},
hub: &hub{},
option: &option{},
network: network,
closeChannel: make(chan struct{}, 1),
systemSignal: make(chan os.Signal, 1),
dispatchers: make(map[string]*dispatcher),
dispatcherMember: map[string]map[string]*Conn{},
currDispatcher: map[string]*dispatcher{},
hub: &hub{},
option: &option{},
network: network,
closeChannel: make(chan struct{}, 1),
systemSignal: make(chan os.Signal, 1),
}
server.ctx, server.cancel = context.WithCancel(context.Background())
server.event = newEvent(server)
@ -67,32 +65,29 @@ func New(network Network, options ...Option) *Server {
// Server is a network server
type Server struct {
*event // events
*runtime // runtime
*option // options
*hub // connection hub
ginServer *gin.Engine // router in HTTP mode
httpServer *http.Server // server in HTTP mode
grpcServer *grpc.Server // server in GRPC mode
gServer *gNet // server in TCP or UDP mode
multiple *MultipleServer // servers in multi-server mode
ants *ants.Pool // goroutine pool
messagePool *concurrent.Pool[*Message] // message pool
ctx context.Context // context
cancel context.CancelFunc // context cancellation
systemDispatcher *dispatcher // system message dispatcher
systemSignal chan os.Signal // system signals
closeChannel chan struct{} // close signal
multipleRuntimeErrorChan chan error // runtime errors in multi-server mode
dispatchers map[string]*dispatcher // dispatcher collection
dispatcherMember map[string]map[string]*Conn // connections contained in each dispatcher
currDispatcher map[string]*dispatcher // dispatcher each connection is currently in
dispatcherLock sync.RWMutex // dispatcher lock
messageCounter atomic.Int64 // message counter
addr string // listen address
network Network // network type
closed uint32 // whether the server is closed
services []func() // services
*event // events
*runtime // runtime
*option // options
*hub // connection hub
dispatcherMgr *dispatcher.Manager[string, *Message] // message dispatcher manager
ginServer *gin.Engine // router in HTTP mode
httpServer *http.Server // server in HTTP mode
grpcServer *grpc.Server // server in GRPC mode
gServer *gNet // server in TCP or UDP mode
multiple *MultipleServer // servers in multi-server mode
ants *ants.Pool // goroutine pool
messagePool *concurrent.Pool[*Message] // message pool
ctx context.Context // context
cancel context.CancelFunc // context cancellation
systemSignal chan os.Signal // system signals
closeChannel chan struct{} // close signal
multipleRuntimeErrorChan chan error // runtime errors in multi-server mode
messageCounter atomic.Int64 // message counter
addr string // listen address
network Network // network type
closed uint32 // whether the server is closed
services []func() // services
}
// preCheckAndAdaptation performs pre-checks and adaptation
@ -221,13 +216,6 @@ func (srv *Server) shutdown(err error) {
srv.ants.Release()
srv.ants = nil
}
srv.dispatcherLock.Lock()
for s, d := range srv.dispatchers {
srv.OnShuntChannelClosedEvent(d.name)
d.close()
delete(srv.dispatchers, s)
}
srv.dispatcherLock.Unlock()
if srv.grpcServer != nil {
srv.grpcServer.GracefulStop()
}
@ -300,107 +288,27 @@ func (srv *Server) GetMessageCount() int64 {
// UseShunt switches the message shunt channel a connection uses; when the channel name does not exist, a new shunt channel is created, otherwise the connection joins the existing one
// - By default, all connections dispatch messages through the system channel; when a shunt channel is specified, messages are dispatched through it instead
// - After creating the server with WithDisableAutomaticReleaseShunt, the shunt channel must always be released via ReleaseShunt once a connection no longer uses it, otherwise memory will leak
// - A shunt channel is marked for eviction when its connections disconnect; once all of its messages are processed and no new connection uses it, it is removed
//
// Some interesting cases:
// - When a connection sends an async message, it is split into two parts: the async part and the callback part. The async part is handled in the current shunt channel, while the callback part is handled in whichever shunt channel is current at callback time
func (srv *Server) UseShunt(conn *Conn, name string) {
srv.dispatcherLock.Lock()
defer srv.dispatcherLock.Unlock()
d, exist := srv.dispatchers[name]
if !exist {
d = generateDispatcher(name, srv.dispatchMessage)
srv.OnShuntChannelCreatedEvent(d.name)
go d.start()
srv.dispatchers[name] = d
}
curr, exist := srv.currDispatcher[conn.GetID()]
if exist {
if curr.name == name {
return
}
delete(srv.dispatcherMember[curr.name], conn.GetID())
if curr.name != serverSystemDispatcher && len(srv.dispatcherMember[curr.name]) == 0 {
delete(srv.dispatchers, curr.name)
srv.OnShuntChannelClosedEvent(d.name)
curr.close()
}
}
srv.currDispatcher[conn.GetID()] = d
member, exist := srv.dispatcherMember[name]
if !exist {
member = map[string]*Conn{}
srv.dispatcherMember[name] = member
}
member[conn.GetID()] = conn
srv.dispatcherMgr.BindProducer(conn.GetID(), name)
}
// HasShunt checks whether a specific message shunt channel exists
func (srv *Server) HasShunt(name string) bool {
srv.dispatcherLock.RLock()
defer srv.dispatcherLock.RUnlock()
_, exist := srv.dispatchers[name]
return exist
return srv.dispatcherMgr.HasDispatcher(name)
}
// GetConnCurrShunt returns the message shunt channel a connection currently uses
func (srv *Server) GetConnCurrShunt(conn *Conn) string {
srv.dispatcherLock.RLock()
defer srv.dispatcherLock.RUnlock()
d, exist := srv.currDispatcher[conn.GetID()]
if exist {
return d.name
}
return serverSystemDispatcher
return srv.dispatcherMgr.GetDispatcher(conn.GetID()).Name()
}
// GetShuntNum returns the number of message shunt channels
func (srv *Server) GetShuntNum() int {
srv.dispatcherLock.RLock()
defer srv.dispatcherLock.RUnlock()
return len(srv.dispatchers)
}
// getConnDispatcher returns the message dispatcher a connection uses
func (srv *Server) getConnDispatcher(conn *Conn) *dispatcher {
if conn == nil {
return srv.systemDispatcher
}
srv.dispatcherLock.RLock()
defer srv.dispatcherLock.RUnlock()
d, exist := srv.currDispatcher[conn.GetID()]
if exist {
return d
}
return srv.systemDispatcher
}
// ReleaseShunt releases a connection from its shunt channel; once the channel no longer holds any connection, the channel itself is released automatically
// - Without the WithDisableAutomaticReleaseShunt option, the resources a connection occupies in a shunt channel are released automatically when the connection closes
// - If the connection is still in use during execution, it is switched to the system channel
func (srv *Server) ReleaseShunt(conn *Conn) {
srv.releaseDispatcher(conn)
}
// releaseDispatcher closes the message dispatcher
func (srv *Server) releaseDispatcher(conn *Conn) {
if conn == nil {
return
}
cid := conn.GetID()
srv.dispatcherLock.Lock()
defer srv.dispatcherLock.Unlock()
d, exist := srv.currDispatcher[cid]
if exist && d.name != serverSystemDispatcher {
delete(srv.dispatcherMember[d.name], cid)
if len(srv.dispatcherMember[d.name]) == 0 {
srv.OnShuntChannelClosedEvent(d.name)
d.close()
delete(srv.dispatchers, d.name)
}
delete(srv.currDispatcher, cid)
}
return srv.dispatcherMgr.GetDispatcherNum()
}
// pushMessage writes a message of a specific type into the server; the message attribute requirements must be strictly followed
@ -409,25 +317,29 @@ func (srv *Server) pushMessage(message *Message) {
srv.messagePool.Release(message)
return
}
var dispatcher *dispatcher
var d *dispatcher.Dispatcher[string, *Message]
switch message.t {
case MessageTypePacket,
MessageTypeShuntTicker, MessageTypeShuntAsync, MessageTypeShuntAsyncCallback,
MessageTypeUniqueShuntAsync, MessageTypeUniqueShuntAsyncCallback,
MessageTypeShunt:
dispatcher = srv.getConnDispatcher(message.conn)
d = srv.dispatcherMgr.GetDispatcher(message.conn.GetID())
case MessageTypeSystem, MessageTypeAsync, MessageTypeUniqueAsync, MessageTypeAsyncCallback, MessageTypeUniqueAsyncCallback, MessageTypeTicker:
dispatcher = srv.systemDispatcher
d = srv.dispatcherMgr.GetSystemDispatcher()
}
if dispatcher == nil {
if d == nil {
return
}
if (message.t == MessageTypeUniqueShuntAsync || message.t == MessageTypeUniqueAsync) && dispatcher.unique(message.name) {
if (message.t == MessageTypeUniqueShuntAsync || message.t == MessageTypeUniqueAsync) && d.Unique(message.name) {
srv.messagePool.Release(message)
return
}
switch message.t {
case MessageTypeShuntAsync, MessageTypeUniqueShuntAsync:
d.IncrCount(message.conn.GetID(), 1)
}
srv.hitMessageStatistics()
dispatcher.put(message)
d.Put(message)
}
func (srv *Server) low(message *Message, present time.Time, expect time.Duration, messageReplace ...string) {
@ -456,7 +368,7 @@ func (srv *Server) low(message *Message, present time.Time, expect time.Duration
}
// dispatchMessage dispatches a message
func (srv *Server) dispatchMessage(dispatcherIns *dispatcher, msg *Message) {
func (srv *Server) dispatchMessage(dispatcherIns *dispatcher.Dispatcher[string, *Message], msg *Message) {
var (
ctx context.Context
cancel context.CancelFunc
@ -476,7 +388,7 @@ func (srv *Server) dispatchMessage(dispatcherIns *dispatcher, msg *Message) {
present := time.Now()
if msg.t != MessageTypeAsync && msg.t != MessageTypeUniqueAsync && msg.t != MessageTypeShuntAsync && msg.t != MessageTypeUniqueShuntAsync {
defer func(cancel context.CancelFunc, srv *Server, dispatcherIns *dispatcher, msg *Message, present time.Time) {
defer func(cancel context.CancelFunc, srv *Server, dispatcherIns *dispatcher.Dispatcher[string, *Message], msg *Message, present time.Time) {
super.Handle(cancel)
if err := super.RecoverTransform(recover()); err != nil {
stack := string(debug.Stack())
@ -485,7 +397,7 @@ func (srv *Server) dispatchMessage(dispatcherIns *dispatcher, msg *Message) {
srv.OnMessageErrorEvent(msg, err)
}
if msg.t == MessageTypeUniqueAsyncCallback || msg.t == MessageTypeUniqueShuntAsyncCallback {
dispatcherIns.antiUnique(msg.name)
dispatcherIns.AntiUnique(msg.name)
}
srv.low(msg, present, time.Millisecond*100)
@ -512,10 +424,14 @@ func (srv *Server) dispatchMessage(dispatcherIns *dispatcher, msg *Message) {
msg.ordinaryHandler()
case MessageTypeAsync, MessageTypeShuntAsync, MessageTypeUniqueAsync, MessageTypeUniqueShuntAsync:
if err := srv.ants.Submit(func() {
defer func(cancel context.CancelFunc, srv *Server, dispatcherIns *dispatcher, msg *Message, present time.Time) {
defer func(cancel context.CancelFunc, srv *Server, dispatcherIns *dispatcher.Dispatcher[string, *Message], msg *Message, present time.Time) {
switch msg.t {
case MessageTypeShuntAsync, MessageTypeUniqueShuntAsync:
dispatcherIns.IncrCount(msg.conn.GetID(), -1)
}
if err := super.RecoverTransform(recover()); err != nil {
if msg.t == MessageTypeUniqueAsync || msg.t == MessageTypeUniqueShuntAsync {
dispatcherIns.antiUnique(msg.name)
dispatcherIns.AntiUnique(msg.name)
}
stack := string(debug.Stack())
log.Error("Server", log.String("MessageType", messageNames[msg.t]), log.Any("error", err), log.String("stack", stack))
@ -550,7 +466,7 @@ func (srv *Server) dispatchMessage(dispatcherIns *dispatcher, msg *Message) {
srv.PushShuntAsyncCallbackMessage(msg.conn, err, msg.errHandler)
return
}
dispatcherIns.antiUnique(msg.name)
dispatcherIns.AntiUnique(msg.name)
if err != nil {
log.Error("Server", log.String("MessageType", messageNames[msg.t]), log.Any("error", err), log.String("stack", string(debug.Stack())))
}
@ -769,7 +685,8 @@ func onMessageSystemInit(srv *Server) {
},
)
srv.startMessageStatistics()
srv.systemDispatcher = generateDispatcher(serverSystemDispatcher, srv.dispatchMessage)
go srv.systemDispatcher.start()
srv.dispatcherMgr = dispatcher.NewManager[string, *Message](srv.dispatcherBufferSize, srv.dispatchMessage).
SetDispatcherCreatedHandler(srv.OnShuntChannelCreatedEvent).
SetDispatcherClosedHandler(srv.OnShuntChannelClosedEvent)
srv.OnMessageReadyEvent()
}

View File

@ -18,7 +18,7 @@ func TestNew(t *testing.T) {
fmt.Println("启动完成")
})
srv.RegConnectionClosedEvent(func(srv *server.Server, conn *server.Conn, err any) {
fmt.Println("关闭", conn.GetID(), err, "Count", srv.GetOnlineCount())
fmt.Println("关闭", conn.GetID(), err, "IncrCount", srv.GetOnlineCount())
})
srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet []byte) {
@ -38,7 +38,7 @@ func TestNew2(t *testing.T) {
fmt.Println("启动完成")
})
srv.RegConnectionClosedEvent(func(srv *server.Server, conn *server.Conn, err any) {
fmt.Println("关闭", conn.GetID(), err, "Count", srv.GetOnlineCount())
fmt.Println("关闭", conn.GetID(), err, "IncrCount", srv.GetOnlineCount())
})
srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet []byte) {

View File

@ -1,15 +1,21 @@
package buffer
// NewRing creates a ring buffer
func NewRing[T any](initSize int) *Ring[T] {
if initSize <= 1 {
panic("initial size must be great than one")
// NewRing creates a concurrency-unsafe ring buffer
// - initSize: the initial capacity
//
// When the initial capacity is less than 2 or not set, the default capacity of 2 is used
func NewRing[T any](initSize ...int) *Ring[T] {
if len(initSize) == 0 {
initSize = append(initSize, 2)
}
if initSize[0] < 2 {
initSize[0] = 2
}
return &Ring[T]{
buf: make([]T, initSize),
initSize: initSize,
size: initSize,
buf: make([]T, initSize[0]),
initSize: initSize[0],
size: initSize[0],
}
}
@ -23,91 +29,120 @@ type Ring[T any] struct {
}
// Read reads one element
func (slf *Ring[T]) Read() (T, error) {
func (b *Ring[T]) Read() (T, error) {
var t T
if slf.r == slf.w {
if b.r == b.w {
return t, ErrBufferIsEmpty
}
v := slf.buf[slf.r]
slf.r++
if slf.r == slf.size {
slf.r = 0
v := b.buf[b.r]
b.r++
if b.r == b.size {
b.r = 0
}
return v, nil
}
// ReadAll reads all buffered elements
func (b *Ring[T]) ReadAll() []T {
if b.r == b.w {
return nil // return nil when there is no data
}
var length int
var data []T
if b.w > b.r {
length = b.w - b.r
} else {
length = len(b.buf) - b.r + b.w
}
data = make([]T, length) // preallocate the result
if b.w > b.r {
copy(data, b.buf[b.r:b.w])
} else {
copied := copy(data, b.buf[b.r:])
copy(data[copied:], b.buf[:b.w])
}
b.r = 0
b.w = 0
return data
}
// Peek returns the next element without consuming it
func (slf *Ring[T]) Peek() (t T, err error) {
if slf.r == slf.w {
func (b *Ring[T]) Peek() (t T, err error) {
if b.r == b.w {
return t, ErrBufferIsEmpty
}
return slf.buf[slf.r], nil
return b.buf[b.r], nil
}
// Write writes one element
func (slf *Ring[T]) Write(v T) {
slf.buf[slf.w] = v
slf.w++
func (b *Ring[T]) Write(v T) {
b.buf[b.w] = v
b.w++
if slf.w == slf.size {
slf.w = 0
if b.w == b.size {
b.w = 0
}
if slf.w == slf.r {
slf.grow()
if b.w == b.r {
b.grow()
}
}
// grow expands the buffer
func (slf *Ring[T]) grow() {
func (b *Ring[T]) grow() {
var size int
if slf.size < 1024 {
size = slf.size * 2
if b.size < 1024 {
size = b.size * 2
} else {
size = slf.size + slf.size/4
size = b.size + b.size/4
}
buf := make([]T, size)
copy(buf[0:], slf.buf[slf.r:])
copy(buf[slf.size-slf.r:], slf.buf[0:slf.r])
copy(buf[0:], b.buf[b.r:])
copy(buf[b.size-b.r:], b.buf[0:b.r])
slf.r = 0
slf.w = slf.size
slf.size = size
slf.buf = buf
b.r = 0
b.w = b.size
b.size = size
b.buf = buf
}
// IsEmpty reports whether the buffer is empty
func (slf *Ring[T]) IsEmpty() bool {
return slf.r == slf.w
func (b *Ring[T]) IsEmpty() bool {
return b.r == b.w
}
// Cap returns the buffer capacity
func (slf *Ring[T]) Cap() int {
return slf.size
func (b *Ring[T]) Cap() int {
return b.size
}
// Len returns the number of buffered elements
func (slf *Ring[T]) Len() int {
if slf.r == slf.w {
func (b *Ring[T]) Len() int {
if b.r == b.w {
return 0
}
if slf.w > slf.r {
return slf.w - slf.r
if b.w > b.r {
return b.w - b.r
}
return slf.size - slf.r + slf.w
return b.size - b.r + b.w
}
// Reset resets the buffer
func (slf *Ring[T]) Reset() {
slf.r = 0
slf.w = 0
slf.size = slf.initSize
slf.buf = make([]T, slf.initSize)
func (b *Ring[T]) Reset() {
b.r = 0
b.w = 0
b.size = b.initSize
b.buf = make([]T, b.initSize)
}
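
A quick worked example of the wrap-around case that ReadAll's two-segment copy handles:

package main

import (
	"fmt"

	"github.com/kercylan98/minotaur/utils/buffer"
)

func main() {
	ring := buffer.NewRing[int](4)
	for i := 0; i < 3; i++ {
		ring.Write(i) // buf = [0 1 2 _], w = 3
	}
	ring.Read()   // 0
	ring.Read()   // 1; r now sits in the middle of the slice
	ring.Write(3) // w wraps around to index 0
	ring.Write(4) // now w < r: the data spans the end of the slice
	fmt.Println(ring.ReadAll()) // [2 3 4], copied as buf[r:] then buf[:w]
}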

View File

@ -0,0 +1,25 @@
package buffer_test
import (
"github.com/kercylan98/minotaur/utils/buffer"
"testing"
)
func BenchmarkRingWrite(b *testing.B) {
ring := buffer.NewRing[int](1024)
b.ResetTimer()
for i := 0; i < b.N; i++ {
ring.Write(i)
}
}
func BenchmarkRingRead(b *testing.B) {
ring := buffer.NewRing[int](1024)
for i := 0; i < b.N; i++ {
ring.Write(i)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = ring.Read()
}
}

utils/buffer/ring_test.go
View File

@ -0,0 +1,14 @@
package buffer_test
import (
"github.com/kercylan98/minotaur/utils/buffer"
"testing"
)
func TestNewRing(t *testing.T) {
ring := buffer.NewRing[int]()
for i := 0; i < 100; i++ {
ring.Write(i)
t.Log(ring.Read())
}
}

View File

@ -0,0 +1,100 @@
package buffer
import (
"sync"
)
// NewRingUnbounded creates a concurrency-safe unbounded buffer backed by a ring buffer
func NewRingUnbounded[T any](bufferSize int) *RingUnbounded[T] {
ru := &RingUnbounded[T]{
ring: NewRing[T](1024),
rc: make(chan T, bufferSize),
closedSignal: make(chan struct{}),
}
ru.cond = sync.NewCond(&ru.rrm)
ru.process()
return ru
}
// RingUnbounded is an unbounded buffer backed by a ring buffer
type RingUnbounded[T any] struct {
ring *Ring[T]
rrm sync.Mutex
cond *sync.Cond
rc chan T
closed bool
closedMutex sync.RWMutex
closedSignal chan struct{}
}
// Write writes data into the buffer
func (b *RingUnbounded[T]) Write(v T) {
b.closedMutex.RLock()
defer b.closedMutex.RUnlock()
if b.closed {
return
}
b.rrm.Lock()
b.ring.Write(v)
b.cond.Signal()
b.rrm.Unlock()
}
// Read returns the channel from which buffered data is read
func (b *RingUnbounded[T]) Read() <-chan T {
return b.rc
}
// Closed reports whether the buffer has been closed
func (b *RingUnbounded[T]) Closed() bool {
b.closedMutex.RLock()
defer b.closedMutex.RUnlock()
return b.closed
}
// Close closes the buffer; after closing, no new data is accepted, but already-buffered data can still be read
func (b *RingUnbounded[T]) Close() <-chan struct{} {
b.closedMutex.Lock()
defer b.closedMutex.Unlock()
if b.closed {
return b.closedSignal
}
b.closed = true
b.rrm.Lock()
b.cond.Signal()
b.rrm.Unlock()
return b.closedSignal
}
// process flushes the ring buffer into the read channel in a background goroutine
func (b *RingUnbounded[T]) process() {
go func(b *RingUnbounded[T]) {
for {
b.closedMutex.RLock()
b.rrm.Lock()
vs := b.ring.ReadAll()
if len(vs) == 0 && !b.closed {
b.closedMutex.RUnlock()
b.cond.Wait()
} else {
b.closedMutex.RUnlock()
}
b.rrm.Unlock()
b.closedMutex.RLock()
if b.closed && len(vs) == 0 {
close(b.rc)
close(b.closedSignal)
b.closedMutex.RUnlock()
break
}
for _, v := range vs {
b.rc <- v
}
b.closedMutex.RUnlock()
}
}(b)
}
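
The Close contract above is the piece that fixes the message-loss problem described in the commit message: writes are rejected after Close, but everything already buffered keeps flowing to Read until drained, and the returned signal channel only closes after the final element is flushed. A minimal sketch:

package main

import (
	"fmt"

	"github.com/kercylan98/minotaur/utils/buffer"
)

func main() {
	ru := buffer.NewRingUnbounded[int](16)
	for i := 0; i < 3; i++ {
		ru.Write(i)
	}
	done := ru.Close()         // no new writes are accepted from here on
	for v := range ru.Read() { // the read channel closes once the ring drains
		fmt.Println("drained:", v)
	}
	<-done // the close signal fires only after the last element is flushed
}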

View File

@ -0,0 +1,48 @@
package buffer_test
import (
"github.com/kercylan98/minotaur/utils/buffer"
"testing"
)
func BenchmarkRingUnbounded_Write(b *testing.B) {
ring := buffer.NewRingUnbounded[int](1024 * 16)
b.ResetTimer()
for i := 0; i < b.N; i++ {
ring.Write(i)
}
}
func BenchmarkRingUnbounded_Read(b *testing.B) {
ring := buffer.NewRingUnbounded[int](1024 * 16)
for i := 0; i < b.N; i++ {
ring.Write(i)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
<-ring.Read()
}
}
func BenchmarkRingUnbounded_Write_Parallel(b *testing.B) {
ring := buffer.NewRingUnbounded[int](1024 * 16)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
ring.Write(1)
}
})
}
func BenchmarkRingUnbounded_Read_Parallel(b *testing.B) {
ring := buffer.NewRingUnbounded[int](1024 * 16)
for i := 0; i < b.N; i++ {
ring.Write(i)
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
<-ring.Read()
}
})
}

View File

@ -0,0 +1,33 @@
package buffer_test
import (
"github.com/kercylan98/minotaur/utils/buffer"
"testing"
)
func TestRingUnbounded_Write2Read(t *testing.T) {
ring := buffer.NewRingUnbounded[int](1024 * 16)
for i := 0; i < 100; i++ {
ring.Write(i)
}
t.Log("write done")
for i := 0; i < 100; i++ {
t.Log(<-ring.Read())
}
t.Log("read done")
}
func TestRingUnbounded_Close(t *testing.T) {
ring := buffer.NewRingUnbounded[int](1024 * 16)
for i := 0; i < 100; i++ {
ring.Write(i)
}
t.Log("write done")
ring.Close()
t.Log("close done")
for v := range ring.Read() {
ring.Write(v)
t.Log(v)
}
t.Log("read done")
}

View File

@ -5,15 +5,46 @@ import (
"testing"
)
func BenchmarkUnboundedBuffer(b *testing.B) {
ub := buffer.NewUnbounded[int]()
func BenchmarkUnbounded_Write(b *testing.B) {
u := buffer.NewUnbounded[int]()
b.ResetTimer()
for i := 0; i < b.N; i++ {
u.Put(i)
}
}
func BenchmarkUnbounded_Read(b *testing.B) {
u := buffer.NewUnbounded[int]()
for i := 0; i < b.N; i++ {
u.Put(i)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
<-u.Get()
u.Load()
}
}
func BenchmarkUnbounded_Write_Parallel(b *testing.B) {
u := buffer.NewUnbounded[int]()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
ub.Put(1)
<-ub.Get()
ub.Load()
u.Put(1)
}
})
}
func BenchmarkUnbounded_Read_Parallel(b *testing.B) {
u := buffer.NewUnbounded[int]()
for i := 0; i < b.N; i++ {
u.Put(i)
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
<-u.Get()
u.Load()
}
})
}