diff --git a/go.mod b/go.mod index f0213c7..04a3d53 100644 --- a/go.mod +++ b/go.mod @@ -50,7 +50,9 @@ require ( github.com/nats-io/nkeys v0.4.4 // indirect github.com/nats-io/nuid v1.0.1 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect + github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5 // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/sasha-s/go-deadlock v0.3.1 // indirect github.com/smarty/assertions v1.15.0 // indirect github.com/templexxx/cpu v0.1.0 // indirect github.com/templexxx/xorsimd v0.4.2 // indirect diff --git a/go.sum b/go.sum index bb33f3b..143cb43 100644 --- a/go.sum +++ b/go.sum @@ -134,6 +134,8 @@ github.com/panjf2000/gnet v1.6.7/go.mod h1:KcOU7QsCaCBjeD5kyshBIamG3d9kAQtlob4Y0 github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo= github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= +github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5 h1:q2e307iGHPdTGp0hoxKjt1H5pDo6utceo3dQVK3I5XQ= +github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5/go.mod h1:jvVRKCrJTQWu0XVbaOlby/2lO20uSCHEMzzplHXte1o= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= @@ -145,6 +147,8 @@ github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTE github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/sasha-s/go-deadlock v0.3.1 h1:sqv7fDNShgjcaxkO0JNcOAlr8B9+cV5Ey/OB71efZx0= +github.com/sasha-s/go-deadlock v0.3.1/go.mod h1:F73l+cr82YSh10GxyRI6qZiCgK64VaZjwesgfQ1/iLM= github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY= github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec= github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY= diff --git a/server/client/client.go b/server/client/client.go index f24bf21..bdadcb3 100644 --- a/server/client/client.go +++ b/server/client/client.go @@ -8,7 +8,6 @@ import ( // NewClient 创建客户端 func NewClient(core Core) *Client { client := &Client{ - cond: sync.NewCond(&sync.Mutex{}), events: new(events), core: core, } @@ -24,11 +23,11 @@ func CloneClient(client *Client) *Client { type Client struct { *events core Core - cond *sync.Cond + mutex sync.Mutex packetPool *concurrent.Pool[*Packet] - packets []*Packet + packets chan *Packet - accumulate []*Packet + accumulate chan *Packet accumulation int // 积压消息数 } @@ -44,14 +43,12 @@ func (slf *Client) Run() error { }() err := <-runState if err != nil { - slf.cond.L.Lock() + slf.mutex.Lock() if slf.packetPool != nil { slf.packetPool.Close() slf.packetPool = nil } - slf.accumulate = append(slf.accumulate, slf.packets...) - slf.packets = nil - slf.cond.L.Unlock() + slf.mutex.Unlock() return err } var wait = new(sync.WaitGroup) @@ -69,12 +66,29 @@ func (slf *Client) IsConnected() bool { // Close 关闭 func (slf *Client) Close(err ...error) { + slf.mutex.Lock() + var unlock bool + defer func() { + if !unlock { + slf.mutex.Unlock() + } + }() slf.core.Close() if slf.packetPool != nil { slf.packetPool.Close() slf.packetPool = nil } + if slf.packets != nil { + close(slf.packets) + slf.packets = nil + } + if slf.accumulate != nil { + close(slf.accumulate) + slf.accumulate = nil + } slf.packets = nil + unlock = true + slf.mutex.Unlock() if len(err) > 0 { slf.OnConnectionClosedEvent(slf, err[0]) } else { @@ -96,7 +110,8 @@ func (slf *Client) Write(packet []byte, callback ...func(err error)) { // write 向连接中写入数据 // - messageType: websocket模式中指定消息类型 func (slf *Client) write(wst int, packet []byte, callback ...func(err error)) { - if slf.packetPool == nil { + slf.mutex.Lock() + if slf.packetPool == nil || slf.packets == nil { var p = &Packet{ wst: wst, data: packet, @@ -104,27 +119,26 @@ func (slf *Client) write(wst int, packet []byte, callback ...func(err error)) { if len(callback) > 0 { p.callback = callback[0] } - slf.cond.L.Lock() - slf.accumulate = append(slf.accumulate, p) + if slf.accumulate == nil { + slf.accumulate = make(chan *Packet, 1024*10) + } + slf.accumulate <- p + } else { + cp := slf.packetPool.Get() + cp.wst = wst + cp.data = packet + if len(callback) > 0 { + cp.callback = callback[0] + } + slf.packets <- cp slf.accumulation = len(slf.accumulate) + len(slf.packets) - slf.cond.L.Unlock() - return } - cp := slf.packetPool.Get() - cp.wst = wst - cp.data = packet - if len(callback) > 0 { - cp.callback = callback[0] - } - slf.cond.L.Lock() - slf.packets = append(slf.packets, cp) - slf.accumulation = len(slf.accumulate) + len(slf.packets) - slf.cond.Signal() - slf.cond.L.Unlock() + slf.mutex.Unlock() } // writeLoop 写循环 func (slf *Client) writeLoop(wait *sync.WaitGroup) { + slf.packets = make(chan *Packet, 1024*10) slf.packetPool = concurrent.NewPool[*Packet](10*1024, func() *Packet { return &Packet{} @@ -134,10 +148,11 @@ func (slf *Client) writeLoop(wait *sync.WaitGroup) { data.callback = nil }, ) - slf.cond.L.Lock() - slf.packets = append(slf.packets, slf.accumulate...) - slf.accumulate = nil - slf.cond.L.Unlock() + go func() { + for packet := range slf.accumulate { + slf.packets <- packet + } + }() defer func() { if err := recover(); err != nil { slf.Close(err.(error)) @@ -145,31 +160,19 @@ func (slf *Client) writeLoop(wait *sync.WaitGroup) { }() wait.Done() - for { - slf.cond.L.Lock() - if slf.packetPool == nil { - slf.cond.L.Unlock() - return + for packet := range slf.packets { + data := packet + var err = slf.core.Write(data) + callback := data.callback + slf.packetPool.Release(data) + if callback != nil { + callback(err) } - if len(slf.packets) == 0 { - slf.cond.Wait() - } - packets := slf.packets[0:] - slf.packets = slf.packets[0:0] - slf.cond.L.Unlock() - for i := 0; i < len(packets); i++ { - data := packets[i] - var err = slf.core.Write(data) - callback := data.callback - slf.packetPool.Release(data) - if callback != nil { - callback(err) - } - if err != nil { - panic(err) - } + if err != nil { + panic(err) } } + } func (slf *Client) onReceive(wst int, packet []byte) { diff --git a/server/conn.go b/server/conn.go index de25fa8..0dda0ab 100644 --- a/server/conn.go +++ b/server/conn.go @@ -18,7 +18,8 @@ func newKcpConn(server *Server, session *kcp.UDPSession) *Conn { c := &Conn{ ctx: server.ctx, connection: &connection{ - cond: sync.NewCond(&sync.Mutex{}), + packets: make(chan *connPacket, 1024*10), + mutex: new(sync.Mutex), server: server, remoteAddr: session.RemoteAddr(), ip: session.RemoteAddr().String(), @@ -41,7 +42,8 @@ func newGNetConn(server *Server, conn gnet.Conn) *Conn { c := &Conn{ ctx: server.ctx, connection: &connection{ - cond: sync.NewCond(&sync.Mutex{}), + packets: make(chan *connPacket, 1024*10), + mutex: new(sync.Mutex), server: server, remoteAddr: conn.RemoteAddr(), ip: conn.RemoteAddr().String(), @@ -64,7 +66,8 @@ func newWebsocketConn(server *Server, ws *websocket.Conn, ip string) *Conn { c := &Conn{ ctx: server.ctx, connection: &connection{ - cond: sync.NewCond(&sync.Mutex{}), + packets: make(chan *connPacket, 1024*10), + mutex: new(sync.Mutex), server: server, remoteAddr: ws.RemoteAddr(), ip: ip, @@ -84,9 +87,10 @@ func newGatewayConn(conn *Conn, connId string) *Conn { c := &Conn{ //ctx: server.ctx, connection: &connection{ - cond: sync.NewCond(&sync.Mutex{}), - server: conn.server, - data: map[any]any{}, + packets: make(chan *connPacket, 1024*10), + mutex: new(sync.Mutex), + server: conn.server, + data: map[any]any{}, }, } c.gw = func(packet []byte) { @@ -100,7 +104,8 @@ func NewEmptyConn(server *Server) *Conn { c := &Conn{ ctx: server.ctx, connection: &connection{ - cond: sync.NewCond(&sync.Mutex{}), + packets: make(chan *connPacket, 1024*10), + mutex: new(sync.Mutex), server: server, remoteAddr: &net.TCPAddr{}, ip: "0.0.0.0:0", @@ -123,6 +128,7 @@ type Conn struct { // connection 长久保持的连接 type connection struct { server *Server + mutex *sync.Mutex remoteAddr net.Addr ip string ws *websocket.Conn @@ -130,9 +136,8 @@ type connection struct { kcp *kcp.UDPSession gw func(packet []byte) data map[any]any - cond *sync.Cond packetPool *concurrent.Pool[*connPacket] - packets []*connPacket + packets chan *connPacket } // IsEmpty 是否是空连接 @@ -144,12 +149,6 @@ func (slf *Conn) IsEmpty() bool { // - 重用连接时,会将当前连接的数据复制到新连接中 // - 通常在于连接断开后,重新连接时使用 func (slf *Conn) Reuse(conn *Conn) { - slf.cond.L.Lock() - conn.cond.L.Lock() - defer func() { - slf.cond.L.Unlock() - conn.cond.L.Unlock() - }() slf.Close() slf.remoteAddr = conn.remoteAddr slf.ip = conn.ip @@ -190,7 +189,10 @@ func (slf *Conn) Close() { slf.packetPool.Close() } slf.packetPool = nil - slf.packets = nil + if slf.packets != nil { + close(slf.packets) + slf.packets = nil + } } // SetData 设置连接数据,该数据将在连接关闭前始终存在 @@ -243,12 +245,14 @@ func (slf *Conn) SetWST(wst int) *Conn { // Write 向连接中写入数据 // - messageType: websocket模式中指定消息类型 func (slf *Conn) Write(packet []byte, callback ...func(err error)) { + slf.mutex.Lock() + defer slf.mutex.Unlock() if slf.gw != nil { slf.gw(packet) return } packet = slf.server.OnConnectionWritePacketBeforeEvent(slf, packet) - if slf.packetPool == nil { + if slf.packetPool == nil || slf.packets == nil { return } cp := slf.packetPool.Get() @@ -257,10 +261,7 @@ func (slf *Conn) Write(packet []byte, callback ...func(err error)) { if len(callback) > 0 { cp.callback = callback[0] } - slf.cond.L.Lock() - slf.packets = append(slf.packets, cp) - slf.cond.Signal() - slf.cond.L.Unlock() + slf.packets <- cp } // writeLoop 写循环 @@ -282,44 +283,32 @@ func (slf *Conn) writeLoop(wait *sync.WaitGroup) { } }() wait.Done() - for { - slf.cond.L.Lock() - if slf.packetPool == nil { - slf.cond.L.Unlock() - return - } - if len(slf.packets) == 0 { - slf.cond.Wait() - } - packets := slf.packets[0:] - slf.packets = slf.packets[0:0] - slf.cond.L.Unlock() - for i := 0; i < len(packets); i++ { - data := packets[i] - var err error - if slf.IsWebsocket() { - err = slf.ws.WriteMessage(data.wst, data.packet) - } else { - if slf.gn != nil { - switch slf.server.network { - case NetworkUdp, NetworkUdp4, NetworkUdp6: - err = slf.gn.SendTo(data.packet) - default: - err = slf.gn.AsyncWrite(data.packet) - } + for packet := range slf.packets { - } else if slf.kcp != nil { - _, err = slf.kcp.Write(data.packet) + data := packet + var err error + if slf.IsWebsocket() { + err = slf.ws.WriteMessage(data.wst, data.packet) + } else { + if slf.gn != nil { + switch slf.server.network { + case NetworkUdp, NetworkUdp4, NetworkUdp6: + err = slf.gn.SendTo(data.packet) + default: + err = slf.gn.AsyncWrite(data.packet) } + } else if slf.kcp != nil { + _, err = slf.kcp.Write(data.packet) } - callback := data.callback - slf.packetPool.Release(data) - if callback != nil { - callback(err) - } - if err != nil { - panic(err) - } + } + callback := data.callback + slf.packetPool.Release(data) + + if callback != nil { + callback(err) + } + if err != nil { + panic(err) } } } diff --git a/server/constants.go b/server/constants.go index b5033b0..f2119c8 100644 --- a/server/constants.go +++ b/server/constants.go @@ -22,7 +22,7 @@ const ( const ( DefaultMessageBufferSize = 1024 - DefaultMessageChannelSize = 1024 * 64 + DefaultMessageChannelSize = 1024 * 1024 DefaultAsyncPoolSize = 256 DefaultWebsocketReadDeadline = 30 * time.Second ) diff --git a/server/message.go b/server/message.go index 4b32257..64d8ed0 100644 --- a/server/message.go +++ b/server/message.go @@ -86,7 +86,9 @@ func (slf *Message) String() string { var s string switch slf.t { case MessageTypePacket: - s = messagePacketVisualization(attrs[1].([]byte)) + if len(attrs) > 1 { + s = messagePacketVisualization(attrs[1].([]byte)) + } default: if len(slf.attrs) == 0 { s = "NoneAttr" diff --git a/server/server.go b/server/server.go index e0072ff..79263b9 100644 --- a/server/server.go +++ b/server/server.go @@ -142,12 +142,13 @@ func (slf *Server) Run(addr string) error { if callback != nil { go callback() } - go func() { + go func(messageChannel <-chan *Message) { messageInitFinish <- struct{}{} - for message := range slf.messageChannel { - slf.dispatchMessage(message) + for message := range messageChannel { + msg := message + slf.dispatchMessage(msg) } - }() + }(slf.messageChannel) } switch slf.network { @@ -568,6 +569,7 @@ func (slf *Server) pushMessage(message *Message) { } } slf.messageChannel <- message + } func (slf *Server) low(message *Message, present time.Time, expect time.Duration, messageReplace ...string) { @@ -593,18 +595,18 @@ func (slf *Server) dispatchMessage(msg *Message) { ) if slf.deadlockDetect > 0 { ctx, cancel = context.WithTimeout(context.Background(), slf.deadlockDetect) - go func() { + go func(ctx context.Context, msg *Message) { select { case <-ctx.Done(): if err := ctx.Err(); err == context.DeadlineExceeded { log.Warn("Server", log.String("MessageType", messageNames[msg.t]), log.Any("SuspectedDeadlock", msg.attrs)) } } - }() + }(ctx, msg) } present := time.Now() - defer func() { + defer func(msg *Message) { if err := recover(); err != nil { stack := string(debug.Stack()) log.Error("Server", log.String("MessageType", messageNames[msg.t]), log.Any("MessageAttrs", msg.attrs), log.Any("error", err), log.String("stack", stack)) @@ -626,7 +628,7 @@ func (slf *Server) dispatchMessage(msg *Message) { slf.messagePool.Release(msg) } - }() + }(msg) var attrs = msg.attrs switch msg.t { case MessageTypePacket: diff --git a/server/server_test.go b/server/server_test.go new file mode 100644 index 0000000..1d76b07 --- /dev/null +++ b/server/server_test.go @@ -0,0 +1,44 @@ +package server_test + +import ( + "fmt" + "github.com/kercylan98/minotaur/server" + "github.com/kercylan98/minotaur/server/client" + "github.com/kercylan98/minotaur/utils/times" + "sync/atomic" + "testing" + "time" +) + +func TestNew(t *testing.T) { + srv := server.New(server.NetworkWebsocket, server.WithMessageBufferSize(1024*1024), server.WithPProf()) + srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet []byte) { + conn.Write(packet) + }) + if err := srv.Run(":9999"); err != nil { + panic(err) + } +} + +func TestNewClient(t *testing.T) { + var total atomic.Int64 + for i := 0; i < 1000; i++ { + cli := client.NewWebsocket("ws://127.0.0.1:9999") + cli.RegConnectionReceivePacketEvent(func(conn *client.Client, wst int, packet []byte) { + fmt.Println(string(packet)) + }) + cli.RegConnectionOpenedEvent(func(conn *client.Client) { + go func() { + for { + cli.WriteWS(2, []byte("hello")) + total.Add(1) + } + }() + }) + if err := cli.Run(); err != nil { + panic(err) + } + } + + time.Sleep(times.Week) +}