From 2d9ffad2ab0277c0a83842c3d27ca31b820a51de Mon Sep 17 00:00:00 2001 From: kercylan98 Date: Tue, 19 Sep 2023 12:40:16 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20client=20=E5=8C=85=E9=87=87?= =?UTF-8?q?=E7=94=A8=E6=97=A0=E7=95=8C=E7=BC=93=E5=86=B2=E5=8C=BA=E6=9B=BF?= =?UTF-8?q?=E4=BB=A3=E9=80=9A=E8=BF=87=20chan=20=E5=AE=9E=E7=8E=B0?= =?UTF-8?q?=E7=9A=84=E5=86=99=E9=80=9A=E9=81=93=EF=BC=8C=E7=A7=BB=E9=99=A4?= =?UTF-8?q?=E6=B6=88=E6=81=AF=E5=A0=86=E7=A7=AF=E5=8A=9F=E8=83=BD=EF=BC=8C?= =?UTF-8?q?=E4=BC=98=E5=8C=96=E4=BB=A3=E7=A0=81=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/client/client.go | 169 ++++++++++++------------------------- server/client/websocket.go | 12 ++- 2 files changed, 64 insertions(+), 117 deletions(-) diff --git a/server/client/client.go b/server/client/client.go index cc42505..5b341a6 100644 --- a/server/client/client.go +++ b/server/client/client.go @@ -1,7 +1,9 @@ package client import ( + "errors" "fmt" + "github.com/kercylan98/minotaur/server/writeloop" "github.com/kercylan98/minotaur/utils/concurrent" "sync" ) @@ -11,28 +13,36 @@ func NewClient(core Core) *Client { client := &Client{ events: new(events), core: core, + closed: true, } return client } // CloneClient 克隆客户端 func CloneClient(client *Client) *Client { - return NewClient(client.core.Clone()) + cli := NewClient(client.core.Clone()) + return cli } // Client 客户端 type Client struct { *events - core Core - mutex sync.Mutex - packetPool *concurrent.Pool[*Packet] - packets chan *Packet - - accumulate chan *Packet - accumulation int // 积压消息数 + core Core + mutex sync.Mutex + closed bool // 是否已关闭 + pool *concurrent.Pool[*Packet] // 数据包缓冲池 + loop *writeloop.WriteLoop[*Packet] // 写入循环 } +// Run 运行客户端 +// - 当客户端已运行时,会先关闭客户端再重新运行 func (slf *Client) Run() error { + slf.mutex.Lock() + if !slf.closed { + slf.mutex.Unlock() + slf.Close() + slf.mutex.Lock() + } var runState = make(chan error) go func(runState chan<- error) { defer func() { @@ -44,49 +54,46 @@ func (slf *Client) Run() error { }(runState) err := <-runState if err != nil { - slf.mutex.Lock() - if slf.packetPool != nil { - slf.packetPool.Close() - slf.packetPool = nil - } slf.mutex.Unlock() return err } - var wait = new(sync.WaitGroup) - wait.Add(1) - go slf.writeLoop(wait) - wait.Wait() + slf.closed = false + slf.pool = concurrent.NewPool[*Packet](10*1024, func() *Packet { + return new(Packet) + }, func(data *Packet) { + data.wst = 0 + data.data = nil + data.callback = nil + }) + slf.loop = writeloop.NewWriteLoop[*Packet](slf.pool, func(message *Packet) error { + err := slf.core.Write(message) + if message.callback != nil { + message.callback(err) + } + return err + }, func(err any) { + slf.Close(errors.New(fmt.Sprint(err))) + }) + slf.mutex.Unlock() + slf.OnConnectionOpenedEvent(slf) return nil } // IsConnected 是否已连接 func (slf *Client) IsConnected() bool { - return slf.packetPool != nil + slf.mutex.Lock() + defer slf.mutex.Unlock() + return !slf.closed } // Close 关闭 func (slf *Client) Close(err ...error) { slf.mutex.Lock() - var unlock bool - defer func() { - if !unlock { - slf.mutex.Unlock() - } - }() + slf.closed = true slf.core.Close() - if slf.packetPool != nil { - slf.packetPool.Close() - slf.packetPool = nil - } - if slf.packets != nil { - close(slf.packets) - } - if slf.accumulate != nil { - close(slf.accumulate) - slf.accumulate = nil - } - unlock = true + slf.loop.Close() + slf.pool.Close() slf.mutex.Unlock() if len(err) > 0 { slf.OnConnectionClosedEvent(slf, err[0]) @@ -110,83 +117,18 @@ func (slf *Client) Write(packet []byte, callback ...func(err error)) { // - messageType: websocket模式中指定消息类型 func (slf *Client) write(wst int, packet []byte, callback ...func(err error)) { slf.mutex.Lock() - if slf.packetPool == nil || slf.packets == nil { - var p = &Packet{ - wst: wst, - data: packet, - } - if len(callback) > 0 { - p.callback = callback[0] - } - if slf.accumulate == nil { - slf.accumulate = make(chan *Packet, 1024*10) - } - slf.accumulate <- p - } else { - cp := slf.packetPool.Get() - cp.wst = wst - cp.data = packet - if len(callback) > 0 { - cp.callback = callback[0] - } - slf.packets <- cp - slf.accumulation = len(slf.accumulate) + len(slf.packets) - } - slf.mutex.Unlock() -} - -// writeLoop 写循环 -func (slf *Client) writeLoop(wait *sync.WaitGroup) { - slf.mutex.Lock() - slf.packets = make(chan *Packet, 1024*10) - slf.packetPool = concurrent.NewPool[*Packet](10*1024, - func() *Packet { - return &Packet{} - }, func(data *Packet) { - data.wst = 0 - data.data = nil - data.callback = nil - }, - ) - go func() { - for packet := range slf.accumulate { - slf.packets <- packet - } - }() - defer func() { - if err := recover(); err != nil { - err, isErr := err.(error) - if !isErr { - err = fmt.Errorf("%v", err) - } - slf.Close(err) - slf.packets = nil - } - }() - wait.Done() - slf.mutex.Unlock() - - for { - packet, ok := <-slf.packets - if !ok { - slf.mutex.Lock() - slf.packets = nil - slf.mutex.Unlock() - break - } - - data := packet - var err = slf.core.Write(data) - callback := data.callback - slf.packetPool.Release(data) - if callback != nil { - callback(err) - } - if err != nil { - panic(err) - } + defer slf.mutex.Unlock() + if slf.closed { + return } + cp := slf.pool.Get() + cp.wst = wst + cp.data = packet + if len(callback) > 0 { + cp.callback = callback[0] + } + slf.loop.Put(cp) } func (slf *Client) onReceive(wst int, packet []byte) { @@ -197,8 +139,3 @@ func (slf *Client) onReceive(wst int, packet []byte) { func (slf *Client) GetServerAddr() string { return slf.core.GetServerAddr() } - -// GetMessageAccumulationTotal 获取消息积压总数 -func (slf *Client) GetMessageAccumulationTotal() int { - return slf.accumulation -} diff --git a/server/client/websocket.go b/server/client/websocket.go index 921c906..1218f73 100644 --- a/server/client/websocket.go +++ b/server/client/websocket.go @@ -3,6 +3,7 @@ package client import ( "github.com/gorilla/websocket" "github.com/kercylan98/minotaur/server" + "sync" ) // NewWebsocket 创建 websocket 客户端 @@ -17,6 +18,7 @@ type Websocket struct { addr string conn *websocket.Conn closed bool + mu sync.Mutex } func (slf *Websocket) Run(runState chan<- error, receive func(wst int, packet []byte)) { @@ -28,7 +30,13 @@ func (slf *Websocket) Run(runState chan<- error, receive func(wst int, packet [] slf.conn = ws slf.closed = false runState <- nil - for !slf.closed { + for { + slf.mu.Lock() + if slf.closed { + slf.mu.Unlock() + break + } + slf.mu.Unlock() messageType, packet, readErr := ws.ReadMessage() if readErr != nil { panic(readErr) @@ -45,6 +53,8 @@ func (slf *Websocket) Write(packet *Packet) error { } func (slf *Websocket) Close() { + slf.mu.Lock() + defer slf.mu.Unlock() slf.closed = true }