From 5e5fe8acca8b2ef7a302997b0211a3103415bdf9 Mon Sep 17 00:00:00 2001 From: kercylan98 Date: Fri, 29 Dec 2023 10:34:18 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20=E4=BC=98=E5=8C=96=20server.Server=20?= =?UTF-8?q?=E8=BF=9E=E6=8E=A5=E7=AE=A1=E7=90=86=E6=9C=BA=E5=88=B6=EF=BC=8C?= =?UTF-8?q?=E4=BC=98=E5=8C=96=20GetOnlineCount=E3=80=81GetOnlineBotCount?= =?UTF-8?q?=20=E6=80=A7=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/constants.go | 1 + server/event.go | 4 +- server/hub.go | 153 ++++++++++++++++++++++++++++++++++++++++++++ server/server.go | 52 +-------------- 4 files changed, 159 insertions(+), 51 deletions(-) create mode 100644 server/hub.go diff --git a/server/constants.go b/server/constants.go index 634dace..d08a16b 100644 --- a/server/constants.go +++ b/server/constants.go @@ -18,6 +18,7 @@ const ( DefaultPacketWarnSize = 1024 * 1024 * 1 // 1MB DefaultDispatcherBufferSize = 1024 * 16 DefaultConnWriteBufferSize = 1024 * 1 + DefaultConnHubBufferSize = 1024 * 1 ) func DefaultWebsocketUpgrader() *websocket.Upgrader { diff --git a/server/event.go b/server/event.go index 9471438..8268265 100644 --- a/server/event.go +++ b/server/event.go @@ -211,7 +211,7 @@ func (slf *event) RegConnectionClosedEvent(handler ConnectionClosedEventHandler, func (slf *event) OnConnectionClosedEvent(conn *Conn, err any) { slf.PushShuntMessage(conn, func() { - slf.Server.online.Del(conn.GetID()) + slf.unregisterConn(conn.GetID()) slf.connectionClosedEventHandlers.RangeValue(func(index int, value ConnectionClosedEventHandler) bool { value(slf.Server, conn, err) return true @@ -231,7 +231,7 @@ func (slf *event) RegConnectionOpenedEvent(handler ConnectionOpenedEventHandler, func (slf *event) OnConnectionOpenedEvent(conn *Conn) { slf.PushSystemMessage(func() { - slf.Server.online.Set(conn.GetID(), conn) + slf.registerConn(conn) slf.connectionOpenedEventHandlers.RangeValue(func(index int, value ConnectionOpenedEventHandler) bool { value(slf.Server, conn) return true diff --git a/server/hub.go b/server/hub.go new file mode 100644 index 0000000..1eebd8a --- /dev/null +++ b/server/hub.go @@ -0,0 +1,153 @@ +package server + +import ( + "context" + "github.com/kercylan98/minotaur/utils/hash" + "sync" +) + +type hub struct { + connections map[string]*Conn // 所有连接 + + register chan *Conn // 注册连接 + unregister chan string // 注销连接 + broadcast chan hubBroadcast // 广播消息 + + botCount int // 机器人数量 + onlineCount int // 在线人数 + + chanMutex sync.RWMutex // 避免外界函数导致的并发问题 +} + +type hubBroadcast struct { + packet []byte // 广播的数据包 + filter func(conn *Conn) bool // 过滤掉返回 false 的连接 +} + +func (h *hub) run(ctx context.Context) { + h.connections = make(map[string]*Conn) + h.register = make(chan *Conn, DefaultConnHubBufferSize) + h.unregister = make(chan string, DefaultConnHubBufferSize) + h.broadcast = make(chan hubBroadcast, DefaultConnHubBufferSize) + go func(ctx context.Context, h *hub) { + for { + select { + case conn := <-h.register: + h.chanMutex.Lock() + h.connections[conn.GetID()] = conn + h.onlineCount++ + if conn.IsBot() { + h.botCount++ + } + h.chanMutex.Unlock() + case connId := <-h.unregister: + h.chanMutex.Lock() + if conn, ok := h.connections[connId]; ok { + h.onlineCount-- + delete(h.connections, conn.GetID()) + if conn.IsBot() { + h.botCount-- + } + } + h.chanMutex.Unlock() + case packet := <-h.broadcast: + h.chanMutex.RLock() + for _, conn := range h.connections { + if packet.filter != nil && !packet.filter(conn) { + continue + } + conn.Write(packet.packet) + } + + case <-ctx.Done(): + h.chanMutex.Lock() + close(h.register) + close(h.unregister) + h.connections = nil + h.botCount = 0 + h.onlineCount = 0 + h.chanMutex.Unlock() + return + + } + } + }(ctx, h) +} + +// registerConn 注册连接 +func (h *hub) registerConn(conn *Conn) { + select { + case h.register <- conn: + default: + } +} + +// unregisterConn 注销连接 +func (h *hub) unregisterConn(id string) { + select { + case h.unregister <- id: + default: + } +} + +// GetOnlineCount 获取在线人数 +func (h *hub) GetOnlineCount() int { + h.chanMutex.RLock() + defer h.chanMutex.RUnlock() + return h.onlineCount +} + +// GetOnlineBotCount 获取在线机器人数量 +func (h *hub) GetOnlineBotCount() int { + h.chanMutex.RLock() + defer h.chanMutex.RUnlock() + return h.botCount +} + +// IsOnline 是否在线 +func (h *hub) IsOnline(id string) bool { + h.chanMutex.RLock() + _, exist := h.connections[id] + h.chanMutex.RUnlock() + return exist +} + +// GetOnlineAll 获取所有在线连接 +func (h *hub) GetOnlineAll() map[string]*Conn { + h.chanMutex.RLock() + cop := hash.Copy(h.connections) + h.chanMutex.RUnlock() + return cop +} + +// GetOnline 获取在线连接 +func (h *hub) GetOnline(id string) *Conn { + h.chanMutex.RLock() + conn := h.connections[id] + h.chanMutex.RUnlock() + return conn +} + +// CloseConn 关闭连接 +func (h *hub) CloseConn(id string) { + h.chanMutex.RLock() + conn := h.connections[id] + h.chanMutex.RUnlock() + if conn != nil { + conn.Close() + } +} + +// Broadcast 广播消息 +func (h *hub) Broadcast(packet []byte, filter ...func(conn *Conn) bool) { + m := hubBroadcast{ + packet: packet, + } + if len(filter) > 0 { + m.filter = filter[0] + } + select { + case h.broadcast <- m: + default: + } +} diff --git a/server/server.go b/server/server.go index 5a30509..3aff595 100644 --- a/server/server.go +++ b/server/server.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "github.com/alphadose/haxmap" "github.com/gin-gonic/gin" "github.com/kercylan98/minotaur/server/internal/logger" "github.com/kercylan98/minotaur/utils/concurrent" @@ -36,9 +35,9 @@ func New(network Network, options ...Option) *Server { dispatcherBufferSize: DefaultDispatcherBufferSize, connWriteBufferSize: DefaultConnWriteBufferSize, }, + hub: &hub{}, option: &option{}, network: network, - online: haxmap.New[string, *Conn](), closeChannel: make(chan struct{}, 1), systemSignal: make(chan os.Signal, 1), dispatchers: make(map[string]*dispatcher), @@ -71,6 +70,7 @@ type Server struct { *event // 事件 *runtime // 运行时 *option // 可选项 + *hub // 连接集合 ginServer *gin.Engine // HTTP模式下的路由器 httpServer *http.Server // HTTP模式下的服务器 grpcServer *grpc.Server // GRPC模式下的服务器 @@ -80,7 +80,6 @@ type Server struct { messagePool *concurrent.Pool[*Message] // 消息池 ctx context.Context // 上下文 cancel context.CancelFunc // 停止上下文 - online *haxmap.Map[string, *Conn] // 在线连接 systemDispatcher *dispatcher // 系统消息分发器 systemSignal chan os.Signal // 系统信号 closeChannel chan struct{} // 关闭信号 @@ -106,6 +105,7 @@ func (srv *Server) preCheckAndAdaptation(addr string) (startState <-chan error, kcp.SystemTimedSched.Close() } + srv.hub.run(srv.ctx) return srv.network.adaptation(srv), nil } @@ -174,52 +174,6 @@ func (srv *Server) TimeoutContext(timeout time.Duration) (context.Context, conte return context.WithTimeout(srv.ctx, timeout) } -// GetOnlineCount 获取在线人数 -func (srv *Server) GetOnlineCount() int { - return int(srv.online.Len()) -} - -// GetOnlineBotCount 获取在线机器人数量 -func (srv *Server) GetOnlineBotCount() int { - var count int - srv.online.ForEach(func(id string, conn *Conn) bool { - if conn.IsBot() { - count++ - } - return true - }) - return count -} - -// GetOnline 获取在线连接 -func (srv *Server) GetOnline(id string) *Conn { - c, _ := srv.online.Get(id) - return c -} - -// GetOnlineAll 获取所有在线连接 -func (srv *Server) GetOnlineAll() map[string]*Conn { - var m = map[string]*Conn{} - srv.online.ForEach(func(id string, conn *Conn) bool { - m[id] = conn - return true - }) - return m -} - -// IsOnline 是否在线 -func (srv *Server) IsOnline(id string) bool { - _, exist := srv.online.Get(id) - return exist -} - -// CloseConn 关闭连接 -func (srv *Server) CloseConn(id string) { - if conn, exist := srv.online.Get(id); exist { - conn.Close() - } -} - // Ticker 获取服务器定时器 func (srv *Server) Ticker() *timer.Ticker { if srv.ticker == nil {