fix: 修复 websocket 客户端死锁问题

This commit is contained in:
kercylan98 2023-08-17 19:26:54 +08:00
parent 351257033e
commit 7bf4e82183
3 changed files with 20 additions and 3 deletions

View File

@ -28,6 +28,8 @@ type Websocket struct {
mutex sync.Mutex mutex sync.Mutex
packetPool *concurrent.Pool[*websocketPacket] packetPool *concurrent.Pool[*websocketPacket]
packets []*websocketPacket packets []*websocketPacket
accumulate []server.Packet
} }
// Run 启动 // Run 启动
@ -88,6 +90,7 @@ func (slf *Websocket) SetData(key string, value any) {
// - messageType: websocket模式中指定消息类型 // - messageType: websocket模式中指定消息类型
func (slf *Websocket) Write(packet server.Packet) { func (slf *Websocket) Write(packet server.Packet) {
if slf.packetPool == nil { if slf.packetPool == nil {
slf.accumulate = append(slf.accumulate, packet)
return return
} }
cp := slf.packetPool.Get() cp := slf.packetPool.Get()
@ -109,6 +112,15 @@ func (slf *Websocket) writeLoop(wait *sync.WaitGroup) {
data.callback = nil data.callback = nil
}, },
) )
slf.mutex.Lock()
for _, packet := range slf.accumulate {
cp := slf.packetPool.Get()
cp.websocketMessageType = packet.WebsocketType
cp.packet = packet.Data
slf.packets = append(slf.packets, cp)
}
slf.accumulate = nil
slf.mutex.Unlock()
defer func() { defer func() {
if err := recover(); err != nil { if err := recover(); err != nil {
slf.Close() slf.Close()
@ -118,6 +130,7 @@ func (slf *Websocket) writeLoop(wait *sync.WaitGroup) {
for { for {
slf.mutex.Lock() slf.mutex.Lock()
if slf.packetPool == nil { if slf.packetPool == nil {
slf.mutex.Unlock()
return return
} }
if len(slf.packets) == 0 { if len(slf.packets) == 0 {

View File

@ -55,5 +55,8 @@ func (slf *Gateway) onConnectionReceivePacket(srv *server.Server, conn *server.C
} }
pd := super.MarshalJSON(&gp) pd := super.MarshalJSON(&gp)
packet.Data = append(pd, 0xff) packet.Data = append(pd, 0xff)
conn.GetData("endpoint").(*Endpoint).Write(packet) var endpoint, exist = conn.GetData("endpoint").(*Endpoint)
if exist {
endpoint.Write(packet)
}
} }

View File

@ -5,10 +5,11 @@ import (
"github.com/kercylan98/minotaur/server" "github.com/kercylan98/minotaur/server"
gateway2 "github.com/kercylan98/minotaur/server/gateway" gateway2 "github.com/kercylan98/minotaur/server/gateway"
"testing" "testing"
"time"
) )
func TestGateway_RunEndpointServer(t *testing.T) { func TestGateway_RunEndpointServer(t *testing.T) {
srv := server.New(server.NetworkWebsocket) srv := server.New(server.NetworkWebsocket, server.WithDeadlockDetect(time.Second*3))
srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet server.Packet) { srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet server.Packet) {
fmt.Println("endpoint receive packet", string(packet.Data)) fmt.Println("endpoint receive packet", string(packet.Data))
conn.Write(packet) conn.Write(packet)
@ -19,7 +20,7 @@ func TestGateway_RunEndpointServer(t *testing.T) {
} }
func TestGateway_Run(t *testing.T) { func TestGateway_Run(t *testing.T) {
srv := server.New(server.NetworkWebsocket) srv := server.New(server.NetworkWebsocket, server.WithDeadlockDetect(time.Second*3))
gw := gateway2.NewGateway(srv) gw := gateway2.NewGateway(srv)
srv.RegStartFinishEvent(func(srv *server.Server) { srv.RegStartFinishEvent(func(srv *server.Server) {
if err := gw.AddEndpoint(gateway2.NewEndpoint("test", "ws://127.0.0.1:8889")); err != nil { if err := gw.AddEndpoint(gateway2.NewEndpoint("test", "ws://127.0.0.1:8889")); err != nil {