fix: 修复 server.Conn 和 client.Client 连接关闭时发生的竞态问题

This commit is contained in:
kercylan98 2023-09-09 13:10:19 +08:00
parent 8fd4e8f722
commit 0215c5449a
4 changed files with 63 additions and 50 deletions

View File

@ -34,14 +34,14 @@ type Client struct {
func (slf *Client) Run() error {
var runState = make(chan error)
go func() {
go func(runState chan<- error) {
defer func() {
if err := recover(); err != nil {
slf.Close(err.(error))
}
}()
slf.core.Run(runState, slf.onReceive)
}()
}(runState)
err := <-runState
if err != nil {
slf.mutex.Lock()
@ -81,13 +81,11 @@ func (slf *Client) Close(err ...error) {
}
if slf.packets != nil {
close(slf.packets)
slf.packets = nil
}
if slf.accumulate != nil {
close(slf.accumulate)
slf.accumulate = nil
}
slf.packets = nil
unlock = true
slf.mutex.Unlock()
if len(err) > 0 {
@ -139,6 +137,7 @@ func (slf *Client) write(wst int, packet []byte, callback ...func(err error)) {
// writeLoop 写循环
func (slf *Client) writeLoop(wait *sync.WaitGroup) {
slf.mutex.Lock()
slf.packets = make(chan *Packet, 1024*10)
slf.packetPool = concurrent.NewPool[*Packet](10*1024,
func() *Packet {
@ -161,11 +160,21 @@ func (slf *Client) writeLoop(wait *sync.WaitGroup) {
err = fmt.Errorf("%v", err)
}
slf.Close(err)
slf.packets = nil
}
}()
wait.Done()
slf.mutex.Unlock()
for {
packet, ok := <-slf.packets
if !ok {
slf.mutex.Lock()
slf.packets = nil
slf.mutex.Unlock()
break
}
for packet := range slf.packets {
data := packet
var err = slf.core.Write(data)
callback := data.callback

View File

@ -17,7 +17,6 @@ func newKcpConn(server *Server, session *kcp.UDPSession) *Conn {
ctx: server.ctx,
connection: &connection{
packets: make(chan *connPacket, 1024*10),
mutex: new(sync.Mutex),
server: server,
remoteAddr: session.RemoteAddr(),
ip: session.RemoteAddr().String(),
@ -41,7 +40,6 @@ func newGNetConn(server *Server, conn gnet.Conn) *Conn {
ctx: server.ctx,
connection: &connection{
packets: make(chan *connPacket, 1024*10),
mutex: new(sync.Mutex),
server: server,
remoteAddr: conn.RemoteAddr(),
ip: conn.RemoteAddr().String(),
@ -65,7 +63,6 @@ func newWebsocketConn(server *Server, ws *websocket.Conn, ip string) *Conn {
ctx: server.ctx,
connection: &connection{
packets: make(chan *connPacket, 1024*10),
mutex: new(sync.Mutex),
server: server,
remoteAddr: ws.RemoteAddr(),
ip: ip,
@ -86,7 +83,6 @@ func newGatewayConn(conn *Conn, connId string) *Conn {
//ctx: server.ctx,
connection: &connection{
packets: make(chan *connPacket, 1024*10),
mutex: new(sync.Mutex),
server: conn.server,
data: map[any]any{},
},
@ -103,7 +99,6 @@ func NewEmptyConn(server *Server) *Conn {
ctx: server.ctx,
connection: &connection{
packets: make(chan *connPacket, 1024*10),
mutex: new(sync.Mutex),
server: server,
remoteAddr: &net.TCPAddr{},
ip: "0.0.0.0:0",
@ -126,9 +121,9 @@ type Conn struct {
// connection 长久保持的连接
type connection struct {
server *Server
mutex *sync.Mutex
close sync.Once
closed bool
closeL sync.Mutex
remoteAddr net.Addr
ip string
ws *websocket.Conn
@ -166,33 +161,6 @@ func (slf *Conn) IsClosed() bool {
return slf.closed
}
// Close 关闭连接
func (slf *Conn) Close(err ...error) {
slf.close.Do(func() {
slf.closed = true
if slf.ws != nil {
_ = slf.ws.Close()
} else if slf.gn != nil {
_ = slf.gn.Close()
} else if slf.kcp != nil {
_ = slf.kcp.Close()
}
if slf.packetPool != nil {
slf.packetPool.Close()
}
slf.packetPool = nil
if slf.packets != nil {
close(slf.packets)
slf.packets = nil
}
if len(err) > 0 {
slf.server.OnConnectionClosedEvent(slf, err[0])
return
}
slf.server.OnConnectionClosedEvent(slf, nil)
})
}
// SetData 设置连接数据,该数据将在连接关闭前始终存在
func (slf *Conn) SetData(key, value any) *Conn {
slf.data[key] = value
@ -243,13 +211,13 @@ func (slf *Conn) SetWST(wst int) *Conn {
// Write 向连接中写入数据
// - messageType: websocket模式中指定消息类型
func (slf *Conn) Write(packet []byte, callback ...func(err error)) {
slf.mutex.Lock()
defer slf.mutex.Unlock()
if slf.gw != nil {
slf.gw(packet)
return
}
packet = slf.server.OnConnectionWritePacketBeforeEvent(slf, packet)
slf.closeL.Lock()
defer slf.closeL.Unlock()
if slf.packetPool == nil || slf.packets == nil {
return
}
@ -276,10 +244,16 @@ func (slf *Conn) writeLoop(wait *sync.WaitGroup) {
defer func() {
if err := recover(); err != nil {
slf.Close()
slf.packets = nil
}
}()
wait.Done()
for packet := range slf.packets {
for {
packet, ok := <-slf.packets
if !ok {
slf.packets = nil
break
}
data := packet
var err error
@ -298,8 +272,9 @@ func (slf *Conn) writeLoop(wait *sync.WaitGroup) {
}
}
callback := data.callback
slf.closeL.Lock()
slf.packetPool.Release(data)
slf.closeL.Unlock()
if callback != nil {
callback(err)
}
@ -308,3 +283,31 @@ func (slf *Conn) writeLoop(wait *sync.WaitGroup) {
}
}
}
// Close 关闭连接
func (slf *Conn) Close(err ...error) {
slf.close.Do(func() {
slf.closeL.Lock()
defer slf.closeL.Unlock()
slf.closed = true
if slf.ws != nil {
_ = slf.ws.Close()
} else if slf.gn != nil {
_ = slf.gn.Close()
} else if slf.kcp != nil {
_ = slf.kcp.Close()
}
if slf.packetPool != nil {
slf.packetPool.Close()
}
slf.packetPool = nil
if slf.packets != nil {
close(slf.packets)
}
if len(err) > 0 {
slf.server.OnConnectionClosedEvent(slf, err[0])
return
}
slf.server.OnConnectionClosedEvent(slf, nil)
})
}

View File

@ -5,6 +5,7 @@ import (
"github.com/kercylan98/minotaur/server"
"github.com/kercylan98/minotaur/server/client"
"github.com/kercylan98/minotaur/utils/log"
"go.uber.org/atomic"
"sync"
"time"
)
@ -38,7 +39,7 @@ type Endpoint struct {
client []*client.Client // 端点客户端
name string // 端点名称
address string // 端点地址
state float64 // 端点健康值0为不可用越高越优
state atomic.Float64 // 端点健康值0为不可用越高越优
evaluator func(costUnixNano float64) float64 // 端点健康值评估函数
connections *haxmap.Map[string, *server.Conn] // 被该端点转发的连接列表
rci time.Duration // 端点重连间隔
@ -50,13 +51,13 @@ func (slf *Endpoint) start(gateway *Gateway, cli *client.Client) {
for {
cur := time.Now().UnixNano()
if err := cli.Run(); err == nil {
slf.state = slf.evaluator(float64(time.Now().UnixNano() - cur))
slf.state.Swap(slf.evaluator(float64(time.Now().UnixNano() - cur)))
break
}
if slf.rci > 0 {
time.Sleep(slf.rci)
} else {
slf.state = 0
slf.state.Swap(0)
break
}
}
@ -83,7 +84,7 @@ func (slf *Endpoint) connect(gateway *Gateway) {
log.Error("Endpoint", log.String("Action", "ReceivePacket"), log.String("Name", slf.name), log.String("Addr", slf.address), log.Err(err))
return
}
slf.state = slf.evaluator(float64(time.Now().UnixNano() - sendTime))
slf.state.Swap(slf.evaluator(float64(time.Now().UnixNano() - sendTime)))
c, ok := slf.connections.Get(addr)
if !ok {
log.Error("Endpoint", log.String("Action", "ReceivePacket"), log.String("Name", slf.name), log.String("Addr", slf.address), log.String("ConnAddr", addr), log.Err(ErrConnectionNotFount))
@ -111,7 +112,7 @@ func (slf *Endpoint) GetAddress() string {
// GetState 获取端点健康值
func (slf *Endpoint) GetState() float64 {
return slf.state
return slf.state.Load()
}
// Forward 转发数据包到该端点

View File

@ -127,7 +127,7 @@ func (slf *Gateway) GetEndpoint(name string) (*Endpoint, error) {
var available = make([]*Endpoint, 0, len(endpoints))
for _, e := range endpoints {
if e.state > 0 {
if e.GetState() > 0 {
available = append(available, e)
}
}
@ -150,7 +150,7 @@ func (slf *Gateway) GetConnEndpoint(name string, conn *server.Conn) (*Endpoint,
slf.cceLock.RLock()
endpoint, exist := slf.cce[conn.GetID()]
slf.cceLock.RUnlock()
if exist && endpoint.state > 0 {
if exist && endpoint.GetState() > 0 {
return endpoint, nil
}
return slf.GetEndpoint(name)
@ -158,7 +158,7 @@ func (slf *Gateway) GetConnEndpoint(name string, conn *server.Conn) (*Endpoint,
// SwitchEndpoint 将端点端点的所有连接切换到另一个端点
func (slf *Gateway) SwitchEndpoint(source, dest *Endpoint) {
if source.name == dest.name && source.address == dest.address || source.state <= 0 || dest.state <= 0 {
if source.name == dest.name && source.address == dest.address || source.GetState() <= 0 || dest.GetState() <= 0 {
return
}
slf.cceLock.Lock()