From 3402c83fd44e6c71db4401ef49667cada293c9c9 Mon Sep 17 00:00:00 2001 From: kercylan98 Date: Fri, 5 Jan 2024 18:44:57 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=20server=20=E5=8C=85?= =?UTF-8?q?=E9=83=A8=E5=88=86=E9=97=AE=E9=A2=98=EF=BC=8C=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=20log=20=E5=8C=85=E5=9C=A8=20init=20=E5=87=BD=E6=95=B0?= =?UTF-8?q?=E8=B0=83=E7=94=A8=E5=8F=AF=E8=83=BD=E4=BA=A7=E7=94=9F=E7=9A=84?= =?UTF-8?q?=E7=A9=BA=E6=8C=87=E9=92=88=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/conn.go | 8 ++-- server/dispatcher.go | 95 ++++++++++++++++---------------------------- server/options.go | 17 ++++---- server/server.go | 10 ++--- utils/log/default.go | 10 ++--- 5 files changed, 54 insertions(+), 86 deletions(-) diff --git a/server/conn.go b/server/conn.go index 3f96a6c..ad5ee9f 100644 --- a/server/conn.go +++ b/server/conn.go @@ -358,11 +358,6 @@ func (slf *Conn) Close(err ...error) { } if slf.ticker != nil { slf.ticker.Release() - } - if !slf.server.runtime.disableAutomaticReleaseShunt { - slf.server.releaseDispatcher(slf) - } else { - } slf.loop.Close() slf.mu.Unlock() @@ -371,4 +366,7 @@ func (slf *Conn) Close(err ...error) { return } slf.server.OnConnectionClosedEvent(slf, nil) + if !slf.server.runtime.disableAutomaticReleaseShunt { + slf.server.releaseDispatcher(slf) + } } diff --git a/server/dispatcher.go b/server/dispatcher.go index 3a94d41..31c8e73 100644 --- a/server/dispatcher.go +++ b/server/dispatcher.go @@ -1,38 +1,33 @@ package server import ( - "context" "github.com/alphadose/haxmap" - "sync" + "github.com/kercylan98/minotaur/utils/buffer" + "sync/atomic" + "time" ) var dispatcherUnique = struct{}{} // generateDispatcher 生成消息分发器 -func generateDispatcher(size int, name string, handler func(dispatcher *dispatcher, message *Message)) *dispatcher { +func generateDispatcher(name string, handler func(dispatcher *dispatcher, message *Message)) *dispatcher { d := &dispatcher{ - name: name, - buffer: make(chan *Message, size), - handler: handler, - uniques: haxmap.New[string, struct{}](), - queueMutex: new(sync.Mutex), + name: name, + buf: buffer.NewUnbounded[*Message](), + handler: handler, + uniques: haxmap.New[string, struct{}](), } - d.ctx, d.cancel = context.WithCancel(context.Background()) - d.queueCond = sync.NewCond(d.queueMutex) return d } // dispatcher 消息分发器 type dispatcher struct { - name string - buffer chan *Message - uniques *haxmap.Map[string, struct{}] - handler func(dispatcher *dispatcher, message *Message) - ctx context.Context - cancel context.CancelFunc - queue []*Message - queueMutex *sync.Mutex - queueCond *sync.Cond + name string + buf *buffer.Unbounded[*Message] + uniques *haxmap.Map[string, struct{}] + handler func(dispatcher *dispatcher, message *Message) + closed uint32 + msgCount int64 } func (d *dispatcher) unique(name string) bool { @@ -45,66 +40,44 @@ func (d *dispatcher) antiUnique(name string) { } func (d *dispatcher) start() { - d.process() + defer d.buf.Close() for { select { - case message, ok := <-d.buffer: + case message, ok := <-d.buf.Get(): if !ok { return } + d.buf.Load() d.handler(d, message) - } - } -} -func (d *dispatcher) process() { - go func(ctx context.Context) { - for { - select { - case <-ctx.Done(): + if atomic.AddInt64(&d.msgCount, -1) <= 0 && atomic.LoadUint32(&d.closed) == 1 { return - default: - d.queueMutex.Lock() - if len(d.queue) == 0 { - d.queueCond.Wait() - } - messages := make([]*Message, len(d.queue)) - copy(messages, d.queue) - d.queue = d.queue[:0] - d.queueMutex.Unlock() - for _, message := range messages { - select { - case d.buffer <- message: - } - } } } - }(d.ctx) + } } func (d *dispatcher) put(message *Message) { - d.queueMutex.Lock() - d.queue = append(d.queue, message) - d.queueCond.Signal() - defer d.queueMutex.Unlock() + if atomic.CompareAndSwapUint32(&d.closed, 1, 1) { + return + } + atomic.AddInt64(&d.msgCount, 1) + d.buf.Put(message) } func (d *dispatcher) close() { - close(d.buffer) - d.cancel() -} + atomic.CompareAndSwapUint32(&d.closed, 0, 1) -func (d *dispatcher) transfer(target *dispatcher) { - if target == nil { - return - } - for { - select { - case message, ok := <-d.buffer: - if !ok { + go func() { + for { + if d.buf.IsClosed() { return } - target.buffer <- message + if atomic.LoadInt64(&d.msgCount) <= 0 { + d.buf.Close() + return + } + time.Sleep(time.Second) } - } + }() } diff --git a/server/options.go b/server/options.go index dbe14c7..8e923cc 100644 --- a/server/options.go +++ b/server/options.go @@ -48,7 +48,6 @@ type runtime struct { messageStatisticsLimit int // 消息统计数量 messageStatistics []*atomic.Int64 // 消息统计数量 messageStatisticsLock *sync.RWMutex // 消息统计锁 - dispatcherBufferSize int // 消息分发器缓冲区大小 connWriteBufferSize int // 连接写入缓冲区大小 disableAutomaticReleaseShunt bool // 是否禁用自动释放分流渠道 websocketUpgrader *websocket.Upgrader // websocket 升级器 @@ -101,14 +100,14 @@ func WithConnWriteBufferSize(size int) Option { // WithDispatcherBufferSize 通过消息分发器缓冲区大小的方式创建服务器 // - 默认值为 DefaultDispatcherBufferSize // - 设置合适的缓冲区大小可以提高服务器性能,但是会占用更多的内存 -func WithDispatcherBufferSize(size int) Option { - return func(srv *Server) { - if size <= 0 { - return - } - srv.dispatcherBufferSize = size - } -} +//func WithDispatcherBufferSize(size int) Option { +// return func(srv *Server) { +// if size <= 0 { +// return +// } +// srv.dispatcherBufferSize = size +// } +//} // WithMessageStatistics 通过消息统计的方式创建服务器 // - 默认不开启,当 duration 和 limit 均大于 0 的时候,服务器将记录每 duration 期间的消息数量,并保留最多 limit 条 diff --git a/server/server.go b/server/server.go index e8e0ded..33d7064 100644 --- a/server/server.go +++ b/server/server.go @@ -32,9 +32,8 @@ func New(network Network, options ...Option) *Server { network.check() server := &Server{ runtime: &runtime{ - packetWarnSize: DefaultPacketWarnSize, - dispatcherBufferSize: DefaultDispatcherBufferSize, - connWriteBufferSize: DefaultConnWriteBufferSize, + packetWarnSize: DefaultPacketWarnSize, + connWriteBufferSize: DefaultConnWriteBufferSize, }, hub: &hub{}, option: &option{}, @@ -307,7 +306,7 @@ func (srv *Server) UseShunt(conn *Conn, name string) { defer srv.dispatcherLock.Unlock() d, exist := srv.dispatchers[name] if !exist { - d = generateDispatcher(srv.dispatcherBufferSize, name, srv.dispatchMessage) + d = generateDispatcher(name, srv.dispatchMessage) srv.OnShuntChannelCreatedEvent(d.name) go d.start() srv.dispatchers[name] = d @@ -322,7 +321,6 @@ func (srv *Server) UseShunt(conn *Conn, name string) { delete(srv.dispatcherMember[curr.name], conn.GetID()) if curr.name != serverSystemDispatcher && len(srv.dispatcherMember[curr.name]) == 0 { delete(srv.dispatchers, curr.name) - curr.transfer(d) srv.OnShuntChannelClosedEvent(d.name) curr.close() } @@ -771,7 +769,7 @@ func onMessageSystemInit(srv *Server) { }, ) srv.startMessageStatistics() - srv.systemDispatcher = generateDispatcher(srv.dispatcherBufferSize, serverSystemDispatcher, srv.dispatchMessage) + srv.systemDispatcher = generateDispatcher(serverSystemDispatcher, srv.dispatchMessage) go srv.systemDispatcher.start() srv.OnMessageReadyEvent() } diff --git a/utils/log/default.go b/utils/log/default.go index dcc0fef..0dc573f 100644 --- a/utils/log/default.go +++ b/utils/log/default.go @@ -9,11 +9,11 @@ import ( "time" ) -var logger atomic.Pointer[Logger] - -func init() { - logger.Store(NewLogger(NewHandler(os.Stdout, NewOptions()))) -} +var logger = func() *atomic.Pointer[Logger] { + var p atomic.Pointer[Logger] + p.Store(NewLogger(NewHandler(os.Stdout, NewOptions()))) + return &p +}() // Default 获取默认的日志记录器 func Default() *Logger {