diff --git a/server/hub.go b/server/hub.go index 1eebd8a..4c31ba9 100644 --- a/server/hub.go +++ b/server/hub.go @@ -17,6 +17,8 @@ type hub struct { onlineCount int // 在线人数 chanMutex sync.RWMutex // 避免外界函数导致的并发问题 + + closed bool } type hubBroadcast struct { @@ -33,39 +35,16 @@ func (h *hub) run(ctx context.Context) { for { select { case conn := <-h.register: - h.chanMutex.Lock() - h.connections[conn.GetID()] = conn - h.onlineCount++ - if conn.IsBot() { - h.botCount++ - } - h.chanMutex.Unlock() - case connId := <-h.unregister: - h.chanMutex.Lock() - if conn, ok := h.connections[connId]; ok { - h.onlineCount-- - delete(h.connections, conn.GetID()) - if conn.IsBot() { - h.botCount-- - } - } - h.chanMutex.Unlock() + h.onRegister(conn) + case id := <-h.unregister: + h.onUnregister(id) case packet := <-h.broadcast: - h.chanMutex.RLock() - for _, conn := range h.connections { - if packet.filter != nil && !packet.filter(conn) { - continue - } - conn.Write(packet.packet) - } - + h.onBroadcast(packet) case <-ctx.Done(): h.chanMutex.Lock() close(h.register) close(h.unregister) - h.connections = nil - h.botCount = 0 - h.onlineCount = 0 + h.closed = true h.chanMutex.Unlock() return @@ -79,6 +58,7 @@ func (h *hub) registerConn(conn *Conn) { select { case h.register <- conn: default: + h.onRegister(conn) } } @@ -87,6 +67,7 @@ func (h *hub) unregisterConn(id string) { select { case h.unregister <- id: default: + h.onUnregister(id) } } @@ -149,5 +130,43 @@ func (h *hub) Broadcast(packet []byte, filter ...func(conn *Conn) bool) { select { case h.broadcast <- m: default: + h.onBroadcast(m) + } +} + +func (h *hub) onRegister(conn *Conn) { + h.chanMutex.Lock() + if h.closed { + conn.Close() + return + } + h.connections[conn.GetID()] = conn + h.onlineCount++ + if conn.IsBot() { + h.botCount++ + } + h.chanMutex.Unlock() +} + +func (h *hub) onUnregister(id string) { + h.chanMutex.Lock() + if conn, ok := h.connections[id]; ok { + h.onlineCount-- + delete(h.connections, conn.GetID()) + if conn.IsBot() { + h.botCount-- + } + } + h.chanMutex.Unlock() +} + +func (h *hub) onBroadcast(packet hubBroadcast) { + h.chanMutex.RLock() + defer h.chanMutex.RUnlock() + for _, conn := range h.connections { + if packet.filter != nil && !packet.filter(conn) { + continue + } + conn.Write(packet.packet) } }