diff --git a/server/http_router.go b/server/http_router.go index 484561f..7a1c1f2 100644 --- a/server/http_router.go +++ b/server/http_router.go @@ -20,7 +20,7 @@ func (slf *HttpRouter[Context]) handlesConvert(handlers []HandlerFunc[Context]) for i := 0; i < len(handlers); i++ { handler := handlers[i] handles = append(handles, func(ctx *gin.Context) { - slf.srv.messageCounter.Add(1) + slf.srv.hitMessageStatistics() defer func() { slf.srv.messageCounter.Add(-1) }() diff --git a/server/options.go b/server/options.go index 61f5f0c..3df0b0b 100644 --- a/server/options.go +++ b/server/options.go @@ -5,6 +5,8 @@ import ( "github.com/kercylan98/minotaur/utils/log" "github.com/kercylan98/minotaur/utils/timer" "google.golang.org/grpc" + "sync" + "sync/atomic" "time" ) @@ -28,19 +30,36 @@ type option struct { } type runtime struct { - deadlockDetect time.Duration // 是否开启死锁检测 - supportMessageTypes map[int]bool // websocket模式下支持的消息类型 - certFile, keyFile string // TLS文件 - messagePoolSize int // 消息池大小 - tickerPool *timer.Pool // 定时器池 - ticker *timer.Ticker // 定时器 - tickerAutonomy bool // 定时器是否独立运行 - connTickerSize int // 连接定时器大小 - websocketReadDeadline time.Duration // websocket连接超时时间 - websocketCompression int // websocket压缩等级 - websocketWriteCompression bool // websocket写入压缩 - limitLife time.Duration // 限制最大生命周期 - packetWarnSize int // 数据包大小警告 + deadlockDetect time.Duration // 是否开启死锁检测 + supportMessageTypes map[int]bool // websocket模式下支持的消息类型 + certFile, keyFile string // TLS文件 + messagePoolSize int // 消息池大小 + tickerPool *timer.Pool // 定时器池 + ticker *timer.Ticker // 定时器 + tickerAutonomy bool // 定时器是否独立运行 + connTickerSize int // 连接定时器大小 + websocketReadDeadline time.Duration // websocket连接超时时间 + websocketCompression int // websocket压缩等级 + websocketWriteCompression bool // websocket写入压缩 + limitLife time.Duration // 限制最大生命周期 + packetWarnSize int // 数据包大小警告 + messageStatisticsDuration time.Duration // 消息统计时长 + messageStatisticsLimit int // 消息统计数量 + messageStatistics []*atomic.Int64 // 消息统计数量 + messageStatisticsLock *sync.RWMutex // 消息统计锁 +} + +// WithMessageStatistics 通过消息统计的方式创建服务器 +// - 默认不开启,当 duration 和 limit 均大于 0 的时候,服务器将记录每 duration 期间的消息数量,并保留最多 limit 条 +func WithMessageStatistics(duration time.Duration, limit int) Option { + return func(srv *Server) { + if duration <= 0 || limit <= 0 { + return + } + srv.messageStatisticsDuration = duration + srv.messageStatisticsLimit = limit + srv.messageStatisticsLock = new(sync.RWMutex) + } } // WithPacketWarnSize 通过数据包大小警告的方式创建服务器,当数据包大小超过指定大小时,将会输出 WARN 类型的日志 diff --git a/server/server.go b/server/server.go index 2a700ec..2b0f924 100644 --- a/server/server.go +++ b/server/server.go @@ -46,6 +46,7 @@ func New(network Network, options ...Option) *Server { dispatcherMember: map[string]map[string]*Conn{}, currDispatcher: map[string]*dispatcher{}, } + server.ctx, server.cancel = context.WithCancel(server.ctx) server.event = newEvent(server) switch network { @@ -92,6 +93,7 @@ type Server struct { ants *ants.Pool // 协程池 messagePool *concurrent.Pool[*Message] // 消息池 ctx context.Context // 上下文 + cancel context.CancelFunc // 停止上下文 online *concurrent.BalanceMap[string, *Conn] // 在线连接 systemDispatcher *dispatcher // 系统消息分发器 network Network // 网络类型 @@ -130,6 +132,7 @@ func (slf *Server) Run(addr string) error { } slf.event.check() slf.addr = addr + slf.startMessageStatistics() slf.systemDispatcher = generateDispatcher(serverSystemDispatcher, slf.dispatchMessage) var protoAddr = fmt.Sprintf("%s://%s", slf.network, slf.addr) var messageInitFinish = make(chan struct{}, 1) @@ -488,6 +491,7 @@ func (slf *Server) shutdown(err error) { slf.multipleRuntimeErrorChan <- err } }() + slf.cancel() if slf.gServer != nil && slf.isRunning { if shutdownErr := gnet.Stop(context.Background(), fmt.Sprintf("%s://%s", slf.network, slf.addr)); err != nil { log.Error("Server", log.Err(shutdownErr)) @@ -669,7 +673,7 @@ func (slf *Server) pushMessage(message *Message) { slf.messagePool.Release(message) return } - slf.messageCounter.Add(1) + slf.hitMessageStatistics() dispatcher.put(message) } @@ -918,3 +922,79 @@ func (slf *Server) PushErrorMessage(err error, errAction MessageErrorAction, mar func (slf *Server) PushShuntMessage(conn *Conn, caller func(), mark ...log.Field) { slf.pushMessage(slf.messagePool.Get().castToShuntMessage(conn, caller, mark...)) } + +// startMessageStatistics 开始消息统计 +func (slf *Server) startMessageStatistics() { + if !slf.HasMessageStatistics() { + return + } + slf.runtime.messageStatistics = append(slf.runtime.messageStatistics, new(atomic.Int64)) + ticker := time.NewTicker(slf.runtime.messageStatisticsDuration) + go func(ctx context.Context, ticker *time.Ticker, r *runtime) { + defer ticker.Stop() + for { + select { + case <-ticker.C: + r.messageStatisticsLock.Lock() + r.messageStatistics = append([]*atomic.Int64{new(atomic.Int64)}, r.messageStatistics...) + if len(r.messageStatistics) > r.messageStatisticsLimit { + r.messageStatistics = r.messageStatistics[:r.messageStatisticsLimit] + } + r.messageStatisticsLock.Unlock() + case <-ctx.Done(): + return + } + } + }(slf.ctx, ticker, slf.runtime) +} + +// hitMessageStatistics 命中消息统计 +func (slf *Server) hitMessageStatistics() { + slf.messageCounter.Add(1) + if !slf.HasMessageStatistics() { + return + } + slf.runtime.messageStatisticsLock.RLock() + slf.runtime.messageStatistics[0].Add(1) + slf.runtime.messageStatisticsLock.RUnlock() +} + +// GetDurationMessageCount 获取当前 WithMessageStatistics 设置的 duration 期间的消息量 +func (slf *Server) GetDurationMessageCount() int64 { + return slf.GetDurationMessageCountByOffset(0) +} + +// GetDurationMessageCountByOffset 获取特定偏移次数的 WithMessageStatistics 设置的 duration 期间的消息量 +// - 该值小于 0 时,将与 GetDurationMessageCount 无异,否则将返回 +n 个期间的消息量,例如 duration 为 1 分钟,limit 为 10,那么 offset 为 1 的情况下,获取的则是上一分钟消息量 +func (slf *Server) GetDurationMessageCountByOffset(offset int) int64 { + if !slf.HasMessageStatistics() { + return 0 + } + slf.runtime.messageStatisticsLock.Lock() + if offset >= len(slf.runtime.messageStatistics)-1 { + slf.runtime.messageStatisticsLock.Unlock() + return 0 + } + v := slf.runtime.messageStatistics[offset].Load() + slf.runtime.messageStatisticsLock.Unlock() + return v +} + +// GetAllDurationMessageCount 获取所有 WithMessageStatistics 设置的 duration 期间的消息量 +func (slf *Server) GetAllDurationMessageCount() []int64 { + if !slf.HasMessageStatistics() { + return nil + } + slf.runtime.messageStatisticsLock.Lock() + var vs = make([]int64, len(slf.runtime.messageStatistics)) + for i, statistic := range slf.runtime.messageStatistics { + vs[i] = statistic.Load() + } + slf.runtime.messageStatisticsLock.Unlock() + return vs +} + +// HasMessageStatistics 是否了开启消息统计 +func (slf *Server) HasMessageStatistics() bool { + return slf.runtime.messageStatisticsLock != nil +}