refactor: client 包采用无界缓冲区替代通过 chan 实现的写通道,移除消息堆积功能,优化代码逻辑
This commit is contained in:
parent
dd1acfd017
commit
2d9ffad2ab
|
@ -1,7 +1,9 @@
|
||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/kercylan98/minotaur/server/writeloop"
|
||||||
"github.com/kercylan98/minotaur/utils/concurrent"
|
"github.com/kercylan98/minotaur/utils/concurrent"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
@ -11,28 +13,36 @@ func NewClient(core Core) *Client {
|
||||||
client := &Client{
|
client := &Client{
|
||||||
events: new(events),
|
events: new(events),
|
||||||
core: core,
|
core: core,
|
||||||
|
closed: true,
|
||||||
}
|
}
|
||||||
return client
|
return client
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloneClient 克隆客户端
|
// CloneClient 克隆客户端
|
||||||
func CloneClient(client *Client) *Client {
|
func CloneClient(client *Client) *Client {
|
||||||
return NewClient(client.core.Clone())
|
cli := NewClient(client.core.Clone())
|
||||||
|
return cli
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client 客户端
|
// Client 客户端
|
||||||
type Client struct {
|
type Client struct {
|
||||||
*events
|
*events
|
||||||
core Core
|
core Core
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
packetPool *concurrent.Pool[*Packet]
|
closed bool // 是否已关闭
|
||||||
packets chan *Packet
|
pool *concurrent.Pool[*Packet] // 数据包缓冲池
|
||||||
|
loop *writeloop.WriteLoop[*Packet] // 写入循环
|
||||||
accumulate chan *Packet
|
|
||||||
accumulation int // 积压消息数
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Run 运行客户端
|
||||||
|
// - 当客户端已运行时,会先关闭客户端再重新运行
|
||||||
func (slf *Client) Run() error {
|
func (slf *Client) Run() error {
|
||||||
|
slf.mutex.Lock()
|
||||||
|
if !slf.closed {
|
||||||
|
slf.mutex.Unlock()
|
||||||
|
slf.Close()
|
||||||
|
slf.mutex.Lock()
|
||||||
|
}
|
||||||
var runState = make(chan error)
|
var runState = make(chan error)
|
||||||
go func(runState chan<- error) {
|
go func(runState chan<- error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
|
@ -44,49 +54,46 @@ func (slf *Client) Run() error {
|
||||||
}(runState)
|
}(runState)
|
||||||
err := <-runState
|
err := <-runState
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slf.mutex.Lock()
|
|
||||||
if slf.packetPool != nil {
|
|
||||||
slf.packetPool.Close()
|
|
||||||
slf.packetPool = nil
|
|
||||||
}
|
|
||||||
slf.mutex.Unlock()
|
slf.mutex.Unlock()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
var wait = new(sync.WaitGroup)
|
slf.closed = false
|
||||||
wait.Add(1)
|
slf.pool = concurrent.NewPool[*Packet](10*1024, func() *Packet {
|
||||||
go slf.writeLoop(wait)
|
return new(Packet)
|
||||||
wait.Wait()
|
}, func(data *Packet) {
|
||||||
|
data.wst = 0
|
||||||
|
data.data = nil
|
||||||
|
data.callback = nil
|
||||||
|
})
|
||||||
|
slf.loop = writeloop.NewWriteLoop[*Packet](slf.pool, func(message *Packet) error {
|
||||||
|
err := slf.core.Write(message)
|
||||||
|
if message.callback != nil {
|
||||||
|
message.callback(err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}, func(err any) {
|
||||||
|
slf.Close(errors.New(fmt.Sprint(err)))
|
||||||
|
})
|
||||||
|
slf.mutex.Unlock()
|
||||||
|
|
||||||
slf.OnConnectionOpenedEvent(slf)
|
slf.OnConnectionOpenedEvent(slf)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsConnected 是否已连接
|
// IsConnected 是否已连接
|
||||||
func (slf *Client) IsConnected() bool {
|
func (slf *Client) IsConnected() bool {
|
||||||
return slf.packetPool != nil
|
slf.mutex.Lock()
|
||||||
|
defer slf.mutex.Unlock()
|
||||||
|
return !slf.closed
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close 关闭
|
// Close 关闭
|
||||||
func (slf *Client) Close(err ...error) {
|
func (slf *Client) Close(err ...error) {
|
||||||
slf.mutex.Lock()
|
slf.mutex.Lock()
|
||||||
var unlock bool
|
slf.closed = true
|
||||||
defer func() {
|
|
||||||
if !unlock {
|
|
||||||
slf.mutex.Unlock()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
slf.core.Close()
|
slf.core.Close()
|
||||||
if slf.packetPool != nil {
|
slf.loop.Close()
|
||||||
slf.packetPool.Close()
|
slf.pool.Close()
|
||||||
slf.packetPool = nil
|
|
||||||
}
|
|
||||||
if slf.packets != nil {
|
|
||||||
close(slf.packets)
|
|
||||||
}
|
|
||||||
if slf.accumulate != nil {
|
|
||||||
close(slf.accumulate)
|
|
||||||
slf.accumulate = nil
|
|
||||||
}
|
|
||||||
unlock = true
|
|
||||||
slf.mutex.Unlock()
|
slf.mutex.Unlock()
|
||||||
if len(err) > 0 {
|
if len(err) > 0 {
|
||||||
slf.OnConnectionClosedEvent(slf, err[0])
|
slf.OnConnectionClosedEvent(slf, err[0])
|
||||||
|
@ -110,83 +117,18 @@ func (slf *Client) Write(packet []byte, callback ...func(err error)) {
|
||||||
// - messageType: websocket模式中指定消息类型
|
// - messageType: websocket模式中指定消息类型
|
||||||
func (slf *Client) write(wst int, packet []byte, callback ...func(err error)) {
|
func (slf *Client) write(wst int, packet []byte, callback ...func(err error)) {
|
||||||
slf.mutex.Lock()
|
slf.mutex.Lock()
|
||||||
if slf.packetPool == nil || slf.packets == nil {
|
defer slf.mutex.Unlock()
|
||||||
var p = &Packet{
|
if slf.closed {
|
||||||
wst: wst,
|
return
|
||||||
data: packet,
|
|
||||||
}
|
|
||||||
if len(callback) > 0 {
|
|
||||||
p.callback = callback[0]
|
|
||||||
}
|
|
||||||
if slf.accumulate == nil {
|
|
||||||
slf.accumulate = make(chan *Packet, 1024*10)
|
|
||||||
}
|
|
||||||
slf.accumulate <- p
|
|
||||||
} else {
|
|
||||||
cp := slf.packetPool.Get()
|
|
||||||
cp.wst = wst
|
|
||||||
cp.data = packet
|
|
||||||
if len(callback) > 0 {
|
|
||||||
cp.callback = callback[0]
|
|
||||||
}
|
|
||||||
slf.packets <- cp
|
|
||||||
slf.accumulation = len(slf.accumulate) + len(slf.packets)
|
|
||||||
}
|
|
||||||
slf.mutex.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// writeLoop 写循环
|
|
||||||
func (slf *Client) writeLoop(wait *sync.WaitGroup) {
|
|
||||||
slf.mutex.Lock()
|
|
||||||
slf.packets = make(chan *Packet, 1024*10)
|
|
||||||
slf.packetPool = concurrent.NewPool[*Packet](10*1024,
|
|
||||||
func() *Packet {
|
|
||||||
return &Packet{}
|
|
||||||
}, func(data *Packet) {
|
|
||||||
data.wst = 0
|
|
||||||
data.data = nil
|
|
||||||
data.callback = nil
|
|
||||||
},
|
|
||||||
)
|
|
||||||
go func() {
|
|
||||||
for packet := range slf.accumulate {
|
|
||||||
slf.packets <- packet
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
defer func() {
|
|
||||||
if err := recover(); err != nil {
|
|
||||||
err, isErr := err.(error)
|
|
||||||
if !isErr {
|
|
||||||
err = fmt.Errorf("%v", err)
|
|
||||||
}
|
|
||||||
slf.Close(err)
|
|
||||||
slf.packets = nil
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
wait.Done()
|
|
||||||
slf.mutex.Unlock()
|
|
||||||
|
|
||||||
for {
|
|
||||||
packet, ok := <-slf.packets
|
|
||||||
if !ok {
|
|
||||||
slf.mutex.Lock()
|
|
||||||
slf.packets = nil
|
|
||||||
slf.mutex.Unlock()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
data := packet
|
|
||||||
var err = slf.core.Write(data)
|
|
||||||
callback := data.callback
|
|
||||||
slf.packetPool.Release(data)
|
|
||||||
if callback != nil {
|
|
||||||
callback(err)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cp := slf.pool.Get()
|
||||||
|
cp.wst = wst
|
||||||
|
cp.data = packet
|
||||||
|
if len(callback) > 0 {
|
||||||
|
cp.callback = callback[0]
|
||||||
|
}
|
||||||
|
slf.loop.Put(cp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (slf *Client) onReceive(wst int, packet []byte) {
|
func (slf *Client) onReceive(wst int, packet []byte) {
|
||||||
|
@ -197,8 +139,3 @@ func (slf *Client) onReceive(wst int, packet []byte) {
|
||||||
func (slf *Client) GetServerAddr() string {
|
func (slf *Client) GetServerAddr() string {
|
||||||
return slf.core.GetServerAddr()
|
return slf.core.GetServerAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMessageAccumulationTotal 获取消息积压总数
|
|
||||||
func (slf *Client) GetMessageAccumulationTotal() int {
|
|
||||||
return slf.accumulation
|
|
||||||
}
|
|
||||||
|
|
|
@ -3,6 +3,7 @@ package client
|
||||||
import (
|
import (
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"github.com/kercylan98/minotaur/server"
|
"github.com/kercylan98/minotaur/server"
|
||||||
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewWebsocket 创建 websocket 客户端
|
// NewWebsocket 创建 websocket 客户端
|
||||||
|
@ -17,6 +18,7 @@ type Websocket struct {
|
||||||
addr string
|
addr string
|
||||||
conn *websocket.Conn
|
conn *websocket.Conn
|
||||||
closed bool
|
closed bool
|
||||||
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func (slf *Websocket) Run(runState chan<- error, receive func(wst int, packet []byte)) {
|
func (slf *Websocket) Run(runState chan<- error, receive func(wst int, packet []byte)) {
|
||||||
|
@ -28,7 +30,13 @@ func (slf *Websocket) Run(runState chan<- error, receive func(wst int, packet []
|
||||||
slf.conn = ws
|
slf.conn = ws
|
||||||
slf.closed = false
|
slf.closed = false
|
||||||
runState <- nil
|
runState <- nil
|
||||||
for !slf.closed {
|
for {
|
||||||
|
slf.mu.Lock()
|
||||||
|
if slf.closed {
|
||||||
|
slf.mu.Unlock()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
slf.mu.Unlock()
|
||||||
messageType, packet, readErr := ws.ReadMessage()
|
messageType, packet, readErr := ws.ReadMessage()
|
||||||
if readErr != nil {
|
if readErr != nil {
|
||||||
panic(readErr)
|
panic(readErr)
|
||||||
|
@ -45,6 +53,8 @@ func (slf *Websocket) Write(packet *Packet) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (slf *Websocket) Close() {
|
func (slf *Websocket) Close() {
|
||||||
|
slf.mu.Lock()
|
||||||
|
defer slf.mu.Unlock()
|
||||||
slf.closed = true
|
slf.closed = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue