diff --git a/server/conn.go b/server/conn.go index 3f97832..afa73d7 100644 --- a/server/conn.go +++ b/server/conn.go @@ -127,6 +127,8 @@ type Conn struct { type connection struct { server *Server mutex *sync.Mutex + close sync.Once + closed bool remoteAddr net.Addr ip string ws *websocket.Conn @@ -159,23 +161,36 @@ func (slf *Conn) GetIP() string { return slf.ip } +// IsClosed 是否已经关闭 +func (slf *Conn) IsClosed() bool { + return slf.closed +} + // Close 关闭连接 -func (slf *Conn) Close() { - if slf.ws != nil { - _ = slf.ws.Close() - } else if slf.gn != nil { - _ = slf.gn.Close() - } else if slf.kcp != nil { - _ = slf.kcp.Close() - } - if slf.packetPool != nil { - slf.packetPool.Close() - } - slf.packetPool = nil - if slf.packets != nil { - close(slf.packets) - slf.packets = nil - } +func (slf *Conn) Close(err ...error) { + slf.close.Do(func() { + slf.closed = true + if slf.ws != nil { + _ = slf.ws.Close() + } else if slf.gn != nil { + _ = slf.gn.Close() + } else if slf.kcp != nil { + _ = slf.kcp.Close() + } + if slf.packetPool != nil { + slf.packetPool.Close() + } + slf.packetPool = nil + if slf.packets != nil { + close(slf.packets) + slf.packets = nil + } + if len(err) > 0 { + slf.server.OnConnectionClosedEvent(slf, err[0]) + return + } + slf.server.OnConnectionClosedEvent(slf, nil) + }) } // SetData 设置连接数据,该数据将在连接关闭前始终存在 @@ -261,9 +276,6 @@ func (slf *Conn) writeLoop(wait *sync.WaitGroup) { defer func() { if err := recover(); err != nil { slf.Close() - // TODO: 以下代码是否需要? - // log.Error("WriteLoop", log.Any("Error", err)) - // debug.PrintStack() } }() wait.Done() diff --git a/server/event.go b/server/event.go index 1ae7241..e81e988 100644 --- a/server/event.go +++ b/server/event.go @@ -184,12 +184,11 @@ func (slf *event) RegConnectionClosedEvent(handle ConnectionClosedEventHandle, p func (slf *event) OnConnectionClosedEvent(conn *Conn, err any) { PushSystemMessage(slf.Server, func() { + slf.Server.online.Delete(conn.GetID()) slf.connectionClosedEventHandles.RangeValue(func(index int, value ConnectionClosedEventHandle) bool { value(slf.Server, conn, err) return true }) - conn.Close() - slf.Server.online.Delete(conn.GetID()) }, "ConnectionClosedEvent") } diff --git a/server/gnet.go b/server/gnet.go index f207568..4d7b61f 100644 --- a/server/gnet.go +++ b/server/gnet.go @@ -26,7 +26,8 @@ func (slf *gNet) OnOpened(c gnet.Conn) (out []byte, action gnet.Action) { } func (slf *gNet) OnClosed(c gnet.Conn, err error) (action gnet.Action) { - slf.OnConnectionClosedEvent(c.Context().(*Conn), err) + conn := c.Context().(*Conn) + conn.Close(err) return } diff --git a/server/server.go b/server/server.go index f91e000..f6fb3ac 100644 --- a/server/server.go +++ b/server/server.go @@ -205,14 +205,21 @@ func (slf *Server) Run(addr string) error { go func(conn *Conn) { defer func() { if err := recover(); err != nil { - slf.OnConnectionClosedEvent(conn, err) + e, ok := err.(error) + if !ok { + e = fmt.Errorf("%v", err) + } + conn.Close(e) } }() buf := make([]byte, 4096) - for { + for !conn.IsClosed() { n, err := conn.kcp.Read(buf) if err != nil { + if conn.IsClosed() { + break + } panic(err) } PushPacketMessage(slf, conn, 0, buf[:n]) @@ -292,15 +299,22 @@ func (slf *Server) Run(addr string) error { defer func() { if err := recover(); err != nil { - slf.OnConnectionClosedEvent(conn, err) + e, ok := err.(error) + if !ok { + e = fmt.Errorf("%v", err) + } + conn.Close(e) } }() - for { + for !conn.IsClosed() { if err := ws.SetReadDeadline(super.If(slf.websocketReadDeadline <= 0, times.Zero, time.Now().Add(slf.websocketReadDeadline))); err != nil { panic(err) } messageType, packet, readErr := ws.ReadMessage() if readErr != nil { + if conn.IsClosed() { + break + } panic(readErr) } if len(slf.supportMessageTypes) > 0 && !slf.supportMessageTypes[messageType] { diff --git a/server/server_test.go b/server/server_test.go index 8d0032c..a85dfed 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -21,14 +21,15 @@ func TestNew(t *testing.T) { } return true }) - var current *server.Conn + srv.RegConnectionClosedEvent(func(srv *server.Server, conn *server.Conn, err any) { + fmt.Println("关闭", conn.GetID(), err, "Count", srv.GetOnlineCount()) + }) srv.RegConnectionOpenedEvent(func(srv *server.Server, conn *server.Conn) { - if current != nil { - current.Reuse(conn) - } else { - current = conn + if srv.GetOnlineCount() > 1 { + conn.Close() } }) + srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet []byte) { conn.Write(packet) })