From 80f38ffe9c5a603e432ee00692c9b9bc35ac65c7 Mon Sep 17 00:00:00 2001 From: kercylan98 Date: Fri, 29 Dec 2023 12:15:29 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=20server.hub=20?= =?UTF-8?q?=E5=B9=BF=E6=92=AD=E6=97=B6=E6=9C=AA=E8=A7=A3=E9=94=81=E7=9A=84?= =?UTF-8?q?=E9=97=AE=E9=A2=98=EF=BC=8C=E4=BC=98=E5=8C=96=E5=A4=84=E7=90=86?= =?UTF-8?q?=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/hub.go | 75 ++++++++++++++++++++++++++++++++------------------- 1 file changed, 47 insertions(+), 28 deletions(-) diff --git a/server/hub.go b/server/hub.go index 1eebd8a..4c31ba9 100644 --- a/server/hub.go +++ b/server/hub.go @@ -17,6 +17,8 @@ type hub struct { onlineCount int // 在线人数 chanMutex sync.RWMutex // 避免外界函数导致的并发问题 + + closed bool } type hubBroadcast struct { @@ -33,39 +35,16 @@ func (h *hub) run(ctx context.Context) { for { select { case conn := <-h.register: - h.chanMutex.Lock() - h.connections[conn.GetID()] = conn - h.onlineCount++ - if conn.IsBot() { - h.botCount++ - } - h.chanMutex.Unlock() - case connId := <-h.unregister: - h.chanMutex.Lock() - if conn, ok := h.connections[connId]; ok { - h.onlineCount-- - delete(h.connections, conn.GetID()) - if conn.IsBot() { - h.botCount-- - } - } - h.chanMutex.Unlock() + h.onRegister(conn) + case id := <-h.unregister: + h.onUnregister(id) case packet := <-h.broadcast: - h.chanMutex.RLock() - for _, conn := range h.connections { - if packet.filter != nil && !packet.filter(conn) { - continue - } - conn.Write(packet.packet) - } - + h.onBroadcast(packet) case <-ctx.Done(): h.chanMutex.Lock() close(h.register) close(h.unregister) - h.connections = nil - h.botCount = 0 - h.onlineCount = 0 + h.closed = true h.chanMutex.Unlock() return @@ -79,6 +58,7 @@ func (h *hub) registerConn(conn *Conn) { select { case h.register <- conn: default: + h.onRegister(conn) } } @@ -87,6 +67,7 @@ func (h *hub) unregisterConn(id string) { select { case h.unregister <- id: default: + h.onUnregister(id) } } @@ -149,5 +130,43 @@ func (h *hub) Broadcast(packet []byte, filter ...func(conn *Conn) bool) { select { case h.broadcast <- m: default: + h.onBroadcast(m) + } +} + +func (h *hub) onRegister(conn *Conn) { + h.chanMutex.Lock() + if h.closed { + conn.Close() + return + } + h.connections[conn.GetID()] = conn + h.onlineCount++ + if conn.IsBot() { + h.botCount++ + } + h.chanMutex.Unlock() +} + +func (h *hub) onUnregister(id string) { + h.chanMutex.Lock() + if conn, ok := h.connections[id]; ok { + h.onlineCount-- + delete(h.connections, conn.GetID()) + if conn.IsBot() { + h.botCount-- + } + } + h.chanMutex.Unlock() +} + +func (h *hub) onBroadcast(packet hubBroadcast) { + h.chanMutex.RLock() + defer h.chanMutex.RUnlock() + for _, conn := range h.connections { + if packet.filter != nil && !packet.filter(conn) { + continue + } + conn.Write(packet.packet) } }