fix: 修复 server 包部分问题,修复 log 包在 init 函数调用可能产生的空指针问题

This commit is contained in:
kercylan98 2024-01-05 18:44:57 +08:00
parent b633f1af9f
commit 3402c83fd4
5 changed files with 54 additions and 86 deletions

View File

@ -358,11 +358,6 @@ func (slf *Conn) Close(err ...error) {
} }
if slf.ticker != nil { if slf.ticker != nil {
slf.ticker.Release() slf.ticker.Release()
}
if !slf.server.runtime.disableAutomaticReleaseShunt {
slf.server.releaseDispatcher(slf)
} else {
} }
slf.loop.Close() slf.loop.Close()
slf.mu.Unlock() slf.mu.Unlock()
@ -371,4 +366,7 @@ func (slf *Conn) Close(err ...error) {
return return
} }
slf.server.OnConnectionClosedEvent(slf, nil) slf.server.OnConnectionClosedEvent(slf, nil)
if !slf.server.runtime.disableAutomaticReleaseShunt {
slf.server.releaseDispatcher(slf)
}
} }

View File

@ -1,38 +1,33 @@
package server package server
import ( import (
"context"
"github.com/alphadose/haxmap" "github.com/alphadose/haxmap"
"sync" "github.com/kercylan98/minotaur/utils/buffer"
"sync/atomic"
"time"
) )
var dispatcherUnique = struct{}{} var dispatcherUnique = struct{}{}
// generateDispatcher 生成消息分发器 // generateDispatcher 生成消息分发器
func generateDispatcher(size int, name string, handler func(dispatcher *dispatcher, message *Message)) *dispatcher { func generateDispatcher(name string, handler func(dispatcher *dispatcher, message *Message)) *dispatcher {
d := &dispatcher{ d := &dispatcher{
name: name, name: name,
buffer: make(chan *Message, size), buf: buffer.NewUnbounded[*Message](),
handler: handler, handler: handler,
uniques: haxmap.New[string, struct{}](), uniques: haxmap.New[string, struct{}](),
queueMutex: new(sync.Mutex),
} }
d.ctx, d.cancel = context.WithCancel(context.Background())
d.queueCond = sync.NewCond(d.queueMutex)
return d return d
} }
// dispatcher 消息分发器 // dispatcher 消息分发器
type dispatcher struct { type dispatcher struct {
name string name string
buffer chan *Message buf *buffer.Unbounded[*Message]
uniques *haxmap.Map[string, struct{}] uniques *haxmap.Map[string, struct{}]
handler func(dispatcher *dispatcher, message *Message) handler func(dispatcher *dispatcher, message *Message)
ctx context.Context closed uint32
cancel context.CancelFunc msgCount int64
queue []*Message
queueMutex *sync.Mutex
queueCond *sync.Cond
} }
func (d *dispatcher) unique(name string) bool { func (d *dispatcher) unique(name string) bool {
@ -45,66 +40,44 @@ func (d *dispatcher) antiUnique(name string) {
} }
func (d *dispatcher) start() { func (d *dispatcher) start() {
d.process() defer d.buf.Close()
for { for {
select { select {
case message, ok := <-d.buffer: case message, ok := <-d.buf.Get():
if !ok { if !ok {
return return
} }
d.buf.Load()
d.handler(d, message) d.handler(d, message)
}
}
}
func (d *dispatcher) process() { if atomic.AddInt64(&d.msgCount, -1) <= 0 && atomic.LoadUint32(&d.closed) == 1 {
go func(ctx context.Context) {
for {
select {
case <-ctx.Done():
return return
default:
d.queueMutex.Lock()
if len(d.queue) == 0 {
d.queueCond.Wait()
}
messages := make([]*Message, len(d.queue))
copy(messages, d.queue)
d.queue = d.queue[:0]
d.queueMutex.Unlock()
for _, message := range messages {
select {
case d.buffer <- message:
} }
} }
} }
}
}(d.ctx)
} }
func (d *dispatcher) put(message *Message) { func (d *dispatcher) put(message *Message) {
d.queueMutex.Lock() if atomic.CompareAndSwapUint32(&d.closed, 1, 1) {
d.queue = append(d.queue, message) return
d.queueCond.Signal() }
defer d.queueMutex.Unlock() atomic.AddInt64(&d.msgCount, 1)
d.buf.Put(message)
} }
func (d *dispatcher) close() { func (d *dispatcher) close() {
close(d.buffer) atomic.CompareAndSwapUint32(&d.closed, 0, 1)
d.cancel()
}
func (d *dispatcher) transfer(target *dispatcher) { go func() {
if target == nil {
return
}
for { for {
select { if d.buf.IsClosed() {
case message, ok := <-d.buffer:
if !ok {
return return
} }
target.buffer <- message if atomic.LoadInt64(&d.msgCount) <= 0 {
d.buf.Close()
return
} }
time.Sleep(time.Second)
} }
}()
} }

View File

@ -48,7 +48,6 @@ type runtime struct {
messageStatisticsLimit int // 消息统计数量 messageStatisticsLimit int // 消息统计数量
messageStatistics []*atomic.Int64 // 消息统计数量 messageStatistics []*atomic.Int64 // 消息统计数量
messageStatisticsLock *sync.RWMutex // 消息统计锁 messageStatisticsLock *sync.RWMutex // 消息统计锁
dispatcherBufferSize int // 消息分发器缓冲区大小
connWriteBufferSize int // 连接写入缓冲区大小 connWriteBufferSize int // 连接写入缓冲区大小
disableAutomaticReleaseShunt bool // 是否禁用自动释放分流渠道 disableAutomaticReleaseShunt bool // 是否禁用自动释放分流渠道
websocketUpgrader *websocket.Upgrader // websocket 升级器 websocketUpgrader *websocket.Upgrader // websocket 升级器
@ -101,14 +100,14 @@ func WithConnWriteBufferSize(size int) Option {
// WithDispatcherBufferSize 通过消息分发器缓冲区大小的方式创建服务器 // WithDispatcherBufferSize 通过消息分发器缓冲区大小的方式创建服务器
// - 默认值为 DefaultDispatcherBufferSize // - 默认值为 DefaultDispatcherBufferSize
// - 设置合适的缓冲区大小可以提高服务器性能,但是会占用更多的内存 // - 设置合适的缓冲区大小可以提高服务器性能,但是会占用更多的内存
func WithDispatcherBufferSize(size int) Option { //func WithDispatcherBufferSize(size int) Option {
return func(srv *Server) { // return func(srv *Server) {
if size <= 0 { // if size <= 0 {
return // return
} // }
srv.dispatcherBufferSize = size // srv.dispatcherBufferSize = size
} // }
} //}
// WithMessageStatistics 通过消息统计的方式创建服务器 // WithMessageStatistics 通过消息统计的方式创建服务器
// - 默认不开启,当 duration 和 limit 均大于 0 的时候,服务器将记录每 duration 期间的消息数量,并保留最多 limit 条 // - 默认不开启,当 duration 和 limit 均大于 0 的时候,服务器将记录每 duration 期间的消息数量,并保留最多 limit 条

View File

@ -33,7 +33,6 @@ func New(network Network, options ...Option) *Server {
server := &Server{ server := &Server{
runtime: &runtime{ runtime: &runtime{
packetWarnSize: DefaultPacketWarnSize, packetWarnSize: DefaultPacketWarnSize,
dispatcherBufferSize: DefaultDispatcherBufferSize,
connWriteBufferSize: DefaultConnWriteBufferSize, connWriteBufferSize: DefaultConnWriteBufferSize,
}, },
hub: &hub{}, hub: &hub{},
@ -307,7 +306,7 @@ func (srv *Server) UseShunt(conn *Conn, name string) {
defer srv.dispatcherLock.Unlock() defer srv.dispatcherLock.Unlock()
d, exist := srv.dispatchers[name] d, exist := srv.dispatchers[name]
if !exist { if !exist {
d = generateDispatcher(srv.dispatcherBufferSize, name, srv.dispatchMessage) d = generateDispatcher(name, srv.dispatchMessage)
srv.OnShuntChannelCreatedEvent(d.name) srv.OnShuntChannelCreatedEvent(d.name)
go d.start() go d.start()
srv.dispatchers[name] = d srv.dispatchers[name] = d
@ -322,7 +321,6 @@ func (srv *Server) UseShunt(conn *Conn, name string) {
delete(srv.dispatcherMember[curr.name], conn.GetID()) delete(srv.dispatcherMember[curr.name], conn.GetID())
if curr.name != serverSystemDispatcher && len(srv.dispatcherMember[curr.name]) == 0 { if curr.name != serverSystemDispatcher && len(srv.dispatcherMember[curr.name]) == 0 {
delete(srv.dispatchers, curr.name) delete(srv.dispatchers, curr.name)
curr.transfer(d)
srv.OnShuntChannelClosedEvent(d.name) srv.OnShuntChannelClosedEvent(d.name)
curr.close() curr.close()
} }
@ -771,7 +769,7 @@ func onMessageSystemInit(srv *Server) {
}, },
) )
srv.startMessageStatistics() srv.startMessageStatistics()
srv.systemDispatcher = generateDispatcher(srv.dispatcherBufferSize, serverSystemDispatcher, srv.dispatchMessage) srv.systemDispatcher = generateDispatcher(serverSystemDispatcher, srv.dispatchMessage)
go srv.systemDispatcher.start() go srv.systemDispatcher.start()
srv.OnMessageReadyEvent() srv.OnMessageReadyEvent()
} }

View File

@ -9,11 +9,11 @@ import (
"time" "time"
) )
var logger atomic.Pointer[Logger] var logger = func() *atomic.Pointer[Logger] {
var p atomic.Pointer[Logger]
func init() { p.Store(NewLogger(NewHandler(os.Stdout, NewOptions())))
logger.Store(NewLogger(NewHandler(os.Stdout, NewOptions()))) return &p
} }()
// Default 获取默认的日志记录器 // Default 获取默认的日志记录器
func Default() *Logger { func Default() *Logger {