From 7bf4e82183c7f1b259e12cf329796812d5da296f Mon Sep 17 00:00:00 2001 From: kercylan98 Date: Thu, 17 Aug 2023 19:26:54 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=20websocket=20?= =?UTF-8?q?=E5=AE=A2=E6=88=B7=E7=AB=AF=E6=AD=BB=E9=94=81=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/client/websocket.go | 13 +++++++++++++ server/gateway/gateway.go | 5 ++++- server/gateway/gateway_test.go | 5 +++-- 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/server/client/websocket.go b/server/client/websocket.go index 52a7e98..0fbee61 100644 --- a/server/client/websocket.go +++ b/server/client/websocket.go @@ -28,6 +28,8 @@ type Websocket struct { mutex sync.Mutex packetPool *concurrent.Pool[*websocketPacket] packets []*websocketPacket + + accumulate []server.Packet } // Run 启动 @@ -88,6 +90,7 @@ func (slf *Websocket) SetData(key string, value any) { // - messageType: websocket模式中指定消息类型 func (slf *Websocket) Write(packet server.Packet) { if slf.packetPool == nil { + slf.accumulate = append(slf.accumulate, packet) return } cp := slf.packetPool.Get() @@ -109,6 +112,15 @@ func (slf *Websocket) writeLoop(wait *sync.WaitGroup) { data.callback = nil }, ) + slf.mutex.Lock() + for _, packet := range slf.accumulate { + cp := slf.packetPool.Get() + cp.websocketMessageType = packet.WebsocketType + cp.packet = packet.Data + slf.packets = append(slf.packets, cp) + } + slf.accumulate = nil + slf.mutex.Unlock() defer func() { if err := recover(); err != nil { slf.Close() @@ -118,6 +130,7 @@ func (slf *Websocket) writeLoop(wait *sync.WaitGroup) { for { slf.mutex.Lock() if slf.packetPool == nil { + slf.mutex.Unlock() return } if len(slf.packets) == 0 { diff --git a/server/gateway/gateway.go b/server/gateway/gateway.go index 21a7d58..51b4995 100644 --- a/server/gateway/gateway.go +++ b/server/gateway/gateway.go @@ -55,5 +55,8 @@ func (slf *Gateway) onConnectionReceivePacket(srv *server.Server, conn *server.C } pd := super.MarshalJSON(&gp) packet.Data = append(pd, 0xff) - conn.GetData("endpoint").(*Endpoint).Write(packet) + var endpoint, exist = conn.GetData("endpoint").(*Endpoint) + if exist { + endpoint.Write(packet) + } } diff --git a/server/gateway/gateway_test.go b/server/gateway/gateway_test.go index 08b9951..744162e 100644 --- a/server/gateway/gateway_test.go +++ b/server/gateway/gateway_test.go @@ -5,10 +5,11 @@ import ( "github.com/kercylan98/minotaur/server" gateway2 "github.com/kercylan98/minotaur/server/gateway" "testing" + "time" ) func TestGateway_RunEndpointServer(t *testing.T) { - srv := server.New(server.NetworkWebsocket) + srv := server.New(server.NetworkWebsocket, server.WithDeadlockDetect(time.Second*3)) srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet server.Packet) { fmt.Println("endpoint receive packet", string(packet.Data)) conn.Write(packet) @@ -19,7 +20,7 @@ func TestGateway_RunEndpointServer(t *testing.T) { } func TestGateway_Run(t *testing.T) { - srv := server.New(server.NetworkWebsocket) + srv := server.New(server.NetworkWebsocket, server.WithDeadlockDetect(time.Second*3)) gw := gateway2.NewGateway(srv) srv.RegStartFinishEvent(func(srv *server.Server) { if err := gw.AddEndpoint(gateway2.NewEndpoint("test", "ws://127.0.0.1:8889")); err != nil {