refactor: client 包采用无界缓冲区替代通过 chan 实现的写通道,移除消息堆积功能,优化代码逻辑

This commit is contained in:
kercylan98 2023-09-19 12:40:16 +08:00
parent dd1acfd017
commit 2d9ffad2ab
2 changed files with 64 additions and 117 deletions

View File

@ -1,7 +1,9 @@
package client package client
import ( import (
"errors"
"fmt" "fmt"
"github.com/kercylan98/minotaur/server/writeloop"
"github.com/kercylan98/minotaur/utils/concurrent" "github.com/kercylan98/minotaur/utils/concurrent"
"sync" "sync"
) )
@ -11,28 +13,36 @@ func NewClient(core Core) *Client {
client := &Client{ client := &Client{
events: new(events), events: new(events),
core: core, core: core,
closed: true,
} }
return client return client
} }
// CloneClient 克隆客户端 // CloneClient 克隆客户端
func CloneClient(client *Client) *Client { func CloneClient(client *Client) *Client {
return NewClient(client.core.Clone()) cli := NewClient(client.core.Clone())
return cli
} }
// Client 客户端 // Client 客户端
type Client struct { type Client struct {
*events *events
core Core core Core
mutex sync.Mutex mutex sync.Mutex
packetPool *concurrent.Pool[*Packet] closed bool // 是否已关闭
packets chan *Packet pool *concurrent.Pool[*Packet] // 数据包缓冲池
loop *writeloop.WriteLoop[*Packet] // 写入循环
accumulate chan *Packet
accumulation int // 积压消息数
} }
// Run 运行客户端
// - 当客户端已运行时,会先关闭客户端再重新运行
func (slf *Client) Run() error { func (slf *Client) Run() error {
slf.mutex.Lock()
if !slf.closed {
slf.mutex.Unlock()
slf.Close()
slf.mutex.Lock()
}
var runState = make(chan error) var runState = make(chan error)
go func(runState chan<- error) { go func(runState chan<- error) {
defer func() { defer func() {
@ -44,49 +54,46 @@ func (slf *Client) Run() error {
}(runState) }(runState)
err := <-runState err := <-runState
if err != nil { if err != nil {
slf.mutex.Lock()
if slf.packetPool != nil {
slf.packetPool.Close()
slf.packetPool = nil
}
slf.mutex.Unlock() slf.mutex.Unlock()
return err return err
} }
var wait = new(sync.WaitGroup) slf.closed = false
wait.Add(1) slf.pool = concurrent.NewPool[*Packet](10*1024, func() *Packet {
go slf.writeLoop(wait) return new(Packet)
wait.Wait() }, func(data *Packet) {
data.wst = 0
data.data = nil
data.callback = nil
})
slf.loop = writeloop.NewWriteLoop[*Packet](slf.pool, func(message *Packet) error {
err := slf.core.Write(message)
if message.callback != nil {
message.callback(err)
}
return err
}, func(err any) {
slf.Close(errors.New(fmt.Sprint(err)))
})
slf.mutex.Unlock()
slf.OnConnectionOpenedEvent(slf) slf.OnConnectionOpenedEvent(slf)
return nil return nil
} }
// IsConnected 是否已连接 // IsConnected 是否已连接
func (slf *Client) IsConnected() bool { func (slf *Client) IsConnected() bool {
return slf.packetPool != nil slf.mutex.Lock()
defer slf.mutex.Unlock()
return !slf.closed
} }
// Close 关闭 // Close 关闭
func (slf *Client) Close(err ...error) { func (slf *Client) Close(err ...error) {
slf.mutex.Lock() slf.mutex.Lock()
var unlock bool slf.closed = true
defer func() {
if !unlock {
slf.mutex.Unlock()
}
}()
slf.core.Close() slf.core.Close()
if slf.packetPool != nil { slf.loop.Close()
slf.packetPool.Close() slf.pool.Close()
slf.packetPool = nil
}
if slf.packets != nil {
close(slf.packets)
}
if slf.accumulate != nil {
close(slf.accumulate)
slf.accumulate = nil
}
unlock = true
slf.mutex.Unlock() slf.mutex.Unlock()
if len(err) > 0 { if len(err) > 0 {
slf.OnConnectionClosedEvent(slf, err[0]) slf.OnConnectionClosedEvent(slf, err[0])
@ -110,83 +117,18 @@ func (slf *Client) Write(packet []byte, callback ...func(err error)) {
// - messageType: websocket模式中指定消息类型 // - messageType: websocket模式中指定消息类型
func (slf *Client) write(wst int, packet []byte, callback ...func(err error)) { func (slf *Client) write(wst int, packet []byte, callback ...func(err error)) {
slf.mutex.Lock() slf.mutex.Lock()
if slf.packetPool == nil || slf.packets == nil { defer slf.mutex.Unlock()
var p = &Packet{ if slf.closed {
wst: wst, return
data: packet,
}
if len(callback) > 0 {
p.callback = callback[0]
}
if slf.accumulate == nil {
slf.accumulate = make(chan *Packet, 1024*10)
}
slf.accumulate <- p
} else {
cp := slf.packetPool.Get()
cp.wst = wst
cp.data = packet
if len(callback) > 0 {
cp.callback = callback[0]
}
slf.packets <- cp
slf.accumulation = len(slf.accumulate) + len(slf.packets)
}
slf.mutex.Unlock()
}
// writeLoop 写循环
func (slf *Client) writeLoop(wait *sync.WaitGroup) {
slf.mutex.Lock()
slf.packets = make(chan *Packet, 1024*10)
slf.packetPool = concurrent.NewPool[*Packet](10*1024,
func() *Packet {
return &Packet{}
}, func(data *Packet) {
data.wst = 0
data.data = nil
data.callback = nil
},
)
go func() {
for packet := range slf.accumulate {
slf.packets <- packet
}
}()
defer func() {
if err := recover(); err != nil {
err, isErr := err.(error)
if !isErr {
err = fmt.Errorf("%v", err)
}
slf.Close(err)
slf.packets = nil
}
}()
wait.Done()
slf.mutex.Unlock()
for {
packet, ok := <-slf.packets
if !ok {
slf.mutex.Lock()
slf.packets = nil
slf.mutex.Unlock()
break
}
data := packet
var err = slf.core.Write(data)
callback := data.callback
slf.packetPool.Release(data)
if callback != nil {
callback(err)
}
if err != nil {
panic(err)
}
} }
cp := slf.pool.Get()
cp.wst = wst
cp.data = packet
if len(callback) > 0 {
cp.callback = callback[0]
}
slf.loop.Put(cp)
} }
func (slf *Client) onReceive(wst int, packet []byte) { func (slf *Client) onReceive(wst int, packet []byte) {
@ -197,8 +139,3 @@ func (slf *Client) onReceive(wst int, packet []byte) {
func (slf *Client) GetServerAddr() string { func (slf *Client) GetServerAddr() string {
return slf.core.GetServerAddr() return slf.core.GetServerAddr()
} }
// GetMessageAccumulationTotal 获取消息积压总数
func (slf *Client) GetMessageAccumulationTotal() int {
return slf.accumulation
}

View File

@ -3,6 +3,7 @@ package client
import ( import (
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/kercylan98/minotaur/server" "github.com/kercylan98/minotaur/server"
"sync"
) )
// NewWebsocket 创建 websocket 客户端 // NewWebsocket 创建 websocket 客户端
@ -17,6 +18,7 @@ type Websocket struct {
addr string addr string
conn *websocket.Conn conn *websocket.Conn
closed bool closed bool
mu sync.Mutex
} }
func (slf *Websocket) Run(runState chan<- error, receive func(wst int, packet []byte)) { func (slf *Websocket) Run(runState chan<- error, receive func(wst int, packet []byte)) {
@ -28,7 +30,13 @@ func (slf *Websocket) Run(runState chan<- error, receive func(wst int, packet []
slf.conn = ws slf.conn = ws
slf.closed = false slf.closed = false
runState <- nil runState <- nil
for !slf.closed { for {
slf.mu.Lock()
if slf.closed {
slf.mu.Unlock()
break
}
slf.mu.Unlock()
messageType, packet, readErr := ws.ReadMessage() messageType, packet, readErr := ws.ReadMessage()
if readErr != nil { if readErr != nil {
panic(readErr) panic(readErr)
@ -45,6 +53,8 @@ func (slf *Websocket) Write(packet *Packet) error {
} }
func (slf *Websocket) Close() { func (slf *Websocket) Close() {
slf.mu.Lock()
defer slf.mu.Unlock()
slf.closed = true slf.closed = true
} }