diff --git a/server/client/websocket.go b/server/client/websocket.go index 52a7e98..0fbee61 100644 --- a/server/client/websocket.go +++ b/server/client/websocket.go @@ -28,6 +28,8 @@ type Websocket struct { mutex sync.Mutex packetPool *concurrent.Pool[*websocketPacket] packets []*websocketPacket + + accumulate []server.Packet } // Run 启动 @@ -88,6 +90,7 @@ func (slf *Websocket) SetData(key string, value any) { // - messageType: websocket模式中指定消息类型 func (slf *Websocket) Write(packet server.Packet) { if slf.packetPool == nil { + slf.accumulate = append(slf.accumulate, packet) return } cp := slf.packetPool.Get() @@ -109,6 +112,15 @@ func (slf *Websocket) writeLoop(wait *sync.WaitGroup) { data.callback = nil }, ) + slf.mutex.Lock() + for _, packet := range slf.accumulate { + cp := slf.packetPool.Get() + cp.websocketMessageType = packet.WebsocketType + cp.packet = packet.Data + slf.packets = append(slf.packets, cp) + } + slf.accumulate = nil + slf.mutex.Unlock() defer func() { if err := recover(); err != nil { slf.Close() @@ -118,6 +130,7 @@ func (slf *Websocket) writeLoop(wait *sync.WaitGroup) { for { slf.mutex.Lock() if slf.packetPool == nil { + slf.mutex.Unlock() return } if len(slf.packets) == 0 { diff --git a/server/gateway/gateway.go b/server/gateway/gateway.go index 21a7d58..51b4995 100644 --- a/server/gateway/gateway.go +++ b/server/gateway/gateway.go @@ -55,5 +55,8 @@ func (slf *Gateway) onConnectionReceivePacket(srv *server.Server, conn *server.C } pd := super.MarshalJSON(&gp) packet.Data = append(pd, 0xff) - conn.GetData("endpoint").(*Endpoint).Write(packet) + var endpoint, exist = conn.GetData("endpoint").(*Endpoint) + if exist { + endpoint.Write(packet) + } } diff --git a/server/gateway/gateway_test.go b/server/gateway/gateway_test.go index 08b9951..744162e 100644 --- a/server/gateway/gateway_test.go +++ b/server/gateway/gateway_test.go @@ -5,10 +5,11 @@ import ( "github.com/kercylan98/minotaur/server" gateway2 "github.com/kercylan98/minotaur/server/gateway" "testing" + "time" ) func TestGateway_RunEndpointServer(t *testing.T) { - srv := server.New(server.NetworkWebsocket) + srv := server.New(server.NetworkWebsocket, server.WithDeadlockDetect(time.Second*3)) srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet server.Packet) { fmt.Println("endpoint receive packet", string(packet.Data)) conn.Write(packet) @@ -19,7 +20,7 @@ func TestGateway_RunEndpointServer(t *testing.T) { } func TestGateway_Run(t *testing.T) { - srv := server.New(server.NetworkWebsocket) + srv := server.New(server.NetworkWebsocket, server.WithDeadlockDetect(time.Second*3)) gw := gateway2.NewGateway(srv) srv.RegStartFinishEvent(func(srv *server.Server) { if err := gw.AddEndpoint(gateway2.NewEndpoint("test", "ws://127.0.0.1:8889")); err != nil {