From 08559d822506bc5695fafae531260ee69447d9bf Mon Sep 17 00:00:00 2001 From: kercylan98 Date: Thu, 24 Aug 2023 10:54:25 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20client=20=E5=8C=85=E5=86=85=E5=AD=98?= =?UTF-8?q?=E6=BA=A2=E5=87=BA=E3=80=81=E6=AD=BB=E5=BE=AA=E7=8E=AF=E7=AD=89?= =?UTF-8?q?=E9=97=AE=E9=A2=98=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/client/client.go | 32 +++++++++++++++++++++++--------- server/client/websocket.go | 1 + server/gateway/endpoint.go | 8 ++++++-- server/gateway/gateway_test.go | 5 ----- 4 files changed, 30 insertions(+), 16 deletions(-) diff --git a/server/client/client.go b/server/client/client.go index b380276..baf45d1 100644 --- a/server/client/client.go +++ b/server/client/client.go @@ -27,10 +27,6 @@ type Client struct { } func (slf *Client) Run() error { - var wait = new(sync.WaitGroup) - wait.Add(1) - go slf.writeLoop(wait) - wait.Wait() var runState = make(chan error) go func() { defer func() { @@ -42,9 +38,20 @@ func (slf *Client) Run() error { }() err := <-runState if err != nil { - slf.Close() + slf.mutex.Lock() + if slf.packetPool != nil { + slf.packetPool.Close() + slf.packetPool = nil + } + slf.accumulate = append(slf.accumulate, slf.packets...) + slf.packets = nil + slf.mutex.Unlock() return err } + var wait = new(sync.WaitGroup) + wait.Add(1) + go slf.writeLoop(wait) + wait.Wait() slf.OnConnectionOpenedEvent(slf) return nil } @@ -83,16 +90,23 @@ func (slf *Client) Write(packet []byte, callback ...func(err error)) { // write 向连接中写入数据 // - messageType: websocket模式中指定消息类型 func (slf *Client) write(wst int, packet []byte, callback ...func(err error)) { + if slf.packetPool == nil { + var p = &Packet{ + wst: wst, + data: packet, + } + if len(callback) > 0 { + p.callback = callback[0] + } + slf.accumulate = append(slf.accumulate, p) + return + } cp := slf.packetPool.Get() cp.wst = wst cp.data = packet if len(callback) > 0 { cp.callback = callback[0] } - if slf.packetPool == nil { - slf.accumulate = append(slf.accumulate, cp) - return - } slf.mutex.Lock() slf.packets = append(slf.packets, cp) slf.mutex.Unlock() diff --git a/server/client/websocket.go b/server/client/websocket.go index 821019c..6c8683e 100644 --- a/server/client/websocket.go +++ b/server/client/websocket.go @@ -26,6 +26,7 @@ func (slf *Websocket) Run(runState chan<- error, receive func(wst int, packet [] return } slf.conn = ws + slf.clsoed = false runState <- nil for !slf.clsoed { messageType, packet, readErr := ws.ReadMessage() diff --git a/server/gateway/endpoint.go b/server/gateway/endpoint.go index b90f22f..aca3608 100644 --- a/server/gateway/endpoint.go +++ b/server/gateway/endpoint.go @@ -70,7 +70,7 @@ func (slf *Endpoint) Connect() { slf.state = slf.evaluator(float64(time.Now().UnixNano() - cur)) break } - time.Sleep(100 * time.Millisecond) + time.Sleep(1000 * time.Millisecond) } } @@ -98,5 +98,9 @@ func (slf *Endpoint) onConnectionReceivePacket(conn *client.Client, wst int, pac panic(err) } slf.state = slf.evaluator(float64(time.Now().UnixNano() - sendTime)) - slf.GetLink(addr).SetWST(wst).Write(packet) + cli := slf.GetLink(addr) + if cli == nil { + return + } + cli.SetWST(wst).Write(packet) } diff --git a/server/gateway/gateway_test.go b/server/gateway/gateway_test.go index 8b3edd0..43ce56c 100644 --- a/server/gateway/gateway_test.go +++ b/server/gateway/gateway_test.go @@ -1,7 +1,6 @@ package gateway_test import ( - "fmt" "github.com/kercylan98/minotaur/server" "github.com/kercylan98/minotaur/server/client" "github.com/kercylan98/minotaur/server/gateway" @@ -11,9 +10,6 @@ import ( func TestGateway_RunEndpointServer(t *testing.T) { srv := server.New(server.NetworkWebsocket, server.WithDeadlockDetect(time.Second*3)) - srv.RegConnectionClosedEvent(func(srv *server.Server, conn *server.Conn, err any) { - fmt.Println(err) - }) srv.RegConnectionPacketPreprocessEvent(func(srv *server.Server, conn *server.Conn, packet []byte, abort func(), usePacket func(newPacket []byte)) { addr, packet, err := gateway.UnmarshalGatewayOutPacket(packet) if err != nil { @@ -35,7 +31,6 @@ func TestGateway_RunEndpointServer(t *testing.T) { return packet }) srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet []byte) { - fmt.Println("endpoint receive packet", string(packet)) conn.Write(packet) }) if err := srv.Run(":8889"); err != nil {