fix: 修复 server.Conn 和 client.Client 连接关闭时发生的竞态问题
This commit is contained in:
parent
8fd4e8f722
commit
0215c5449a
|
@ -34,14 +34,14 @@ type Client struct {
|
|||
|
||||
func (slf *Client) Run() error {
|
||||
var runState = make(chan error)
|
||||
go func() {
|
||||
go func(runState chan<- error) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
slf.Close(err.(error))
|
||||
}
|
||||
}()
|
||||
slf.core.Run(runState, slf.onReceive)
|
||||
}()
|
||||
}(runState)
|
||||
err := <-runState
|
||||
if err != nil {
|
||||
slf.mutex.Lock()
|
||||
|
@ -81,13 +81,11 @@ func (slf *Client) Close(err ...error) {
|
|||
}
|
||||
if slf.packets != nil {
|
||||
close(slf.packets)
|
||||
slf.packets = nil
|
||||
}
|
||||
if slf.accumulate != nil {
|
||||
close(slf.accumulate)
|
||||
slf.accumulate = nil
|
||||
}
|
||||
slf.packets = nil
|
||||
unlock = true
|
||||
slf.mutex.Unlock()
|
||||
if len(err) > 0 {
|
||||
|
@ -139,6 +137,7 @@ func (slf *Client) write(wst int, packet []byte, callback ...func(err error)) {
|
|||
|
||||
// writeLoop 写循环
|
||||
func (slf *Client) writeLoop(wait *sync.WaitGroup) {
|
||||
slf.mutex.Lock()
|
||||
slf.packets = make(chan *Packet, 1024*10)
|
||||
slf.packetPool = concurrent.NewPool[*Packet](10*1024,
|
||||
func() *Packet {
|
||||
|
@ -161,11 +160,21 @@ func (slf *Client) writeLoop(wait *sync.WaitGroup) {
|
|||
err = fmt.Errorf("%v", err)
|
||||
}
|
||||
slf.Close(err)
|
||||
slf.packets = nil
|
||||
}
|
||||
}()
|
||||
wait.Done()
|
||||
slf.mutex.Unlock()
|
||||
|
||||
for {
|
||||
packet, ok := <-slf.packets
|
||||
if !ok {
|
||||
slf.mutex.Lock()
|
||||
slf.packets = nil
|
||||
slf.mutex.Unlock()
|
||||
break
|
||||
}
|
||||
|
||||
for packet := range slf.packets {
|
||||
data := packet
|
||||
var err = slf.core.Write(data)
|
||||
callback := data.callback
|
||||
|
|
|
@ -17,7 +17,6 @@ func newKcpConn(server *Server, session *kcp.UDPSession) *Conn {
|
|||
ctx: server.ctx,
|
||||
connection: &connection{
|
||||
packets: make(chan *connPacket, 1024*10),
|
||||
mutex: new(sync.Mutex),
|
||||
server: server,
|
||||
remoteAddr: session.RemoteAddr(),
|
||||
ip: session.RemoteAddr().String(),
|
||||
|
@ -41,7 +40,6 @@ func newGNetConn(server *Server, conn gnet.Conn) *Conn {
|
|||
ctx: server.ctx,
|
||||
connection: &connection{
|
||||
packets: make(chan *connPacket, 1024*10),
|
||||
mutex: new(sync.Mutex),
|
||||
server: server,
|
||||
remoteAddr: conn.RemoteAddr(),
|
||||
ip: conn.RemoteAddr().String(),
|
||||
|
@ -65,7 +63,6 @@ func newWebsocketConn(server *Server, ws *websocket.Conn, ip string) *Conn {
|
|||
ctx: server.ctx,
|
||||
connection: &connection{
|
||||
packets: make(chan *connPacket, 1024*10),
|
||||
mutex: new(sync.Mutex),
|
||||
server: server,
|
||||
remoteAddr: ws.RemoteAddr(),
|
||||
ip: ip,
|
||||
|
@ -86,7 +83,6 @@ func newGatewayConn(conn *Conn, connId string) *Conn {
|
|||
//ctx: server.ctx,
|
||||
connection: &connection{
|
||||
packets: make(chan *connPacket, 1024*10),
|
||||
mutex: new(sync.Mutex),
|
||||
server: conn.server,
|
||||
data: map[any]any{},
|
||||
},
|
||||
|
@ -103,7 +99,6 @@ func NewEmptyConn(server *Server) *Conn {
|
|||
ctx: server.ctx,
|
||||
connection: &connection{
|
||||
packets: make(chan *connPacket, 1024*10),
|
||||
mutex: new(sync.Mutex),
|
||||
server: server,
|
||||
remoteAddr: &net.TCPAddr{},
|
||||
ip: "0.0.0.0:0",
|
||||
|
@ -126,9 +121,9 @@ type Conn struct {
|
|||
// connection 长久保持的连接
|
||||
type connection struct {
|
||||
server *Server
|
||||
mutex *sync.Mutex
|
||||
close sync.Once
|
||||
closed bool
|
||||
closeL sync.Mutex
|
||||
remoteAddr net.Addr
|
||||
ip string
|
||||
ws *websocket.Conn
|
||||
|
@ -166,33 +161,6 @@ func (slf *Conn) IsClosed() bool {
|
|||
return slf.closed
|
||||
}
|
||||
|
||||
// Close 关闭连接
|
||||
func (slf *Conn) Close(err ...error) {
|
||||
slf.close.Do(func() {
|
||||
slf.closed = true
|
||||
if slf.ws != nil {
|
||||
_ = slf.ws.Close()
|
||||
} else if slf.gn != nil {
|
||||
_ = slf.gn.Close()
|
||||
} else if slf.kcp != nil {
|
||||
_ = slf.kcp.Close()
|
||||
}
|
||||
if slf.packetPool != nil {
|
||||
slf.packetPool.Close()
|
||||
}
|
||||
slf.packetPool = nil
|
||||
if slf.packets != nil {
|
||||
close(slf.packets)
|
||||
slf.packets = nil
|
||||
}
|
||||
if len(err) > 0 {
|
||||
slf.server.OnConnectionClosedEvent(slf, err[0])
|
||||
return
|
||||
}
|
||||
slf.server.OnConnectionClosedEvent(slf, nil)
|
||||
})
|
||||
}
|
||||
|
||||
// SetData 设置连接数据,该数据将在连接关闭前始终存在
|
||||
func (slf *Conn) SetData(key, value any) *Conn {
|
||||
slf.data[key] = value
|
||||
|
@ -243,13 +211,13 @@ func (slf *Conn) SetWST(wst int) *Conn {
|
|||
// Write 向连接中写入数据
|
||||
// - messageType: websocket模式中指定消息类型
|
||||
func (slf *Conn) Write(packet []byte, callback ...func(err error)) {
|
||||
slf.mutex.Lock()
|
||||
defer slf.mutex.Unlock()
|
||||
if slf.gw != nil {
|
||||
slf.gw(packet)
|
||||
return
|
||||
}
|
||||
packet = slf.server.OnConnectionWritePacketBeforeEvent(slf, packet)
|
||||
slf.closeL.Lock()
|
||||
defer slf.closeL.Unlock()
|
||||
if slf.packetPool == nil || slf.packets == nil {
|
||||
return
|
||||
}
|
||||
|
@ -276,10 +244,16 @@ func (slf *Conn) writeLoop(wait *sync.WaitGroup) {
|
|||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
slf.Close()
|
||||
slf.packets = nil
|
||||
}
|
||||
}()
|
||||
wait.Done()
|
||||
for packet := range slf.packets {
|
||||
for {
|
||||
packet, ok := <-slf.packets
|
||||
if !ok {
|
||||
slf.packets = nil
|
||||
break
|
||||
}
|
||||
|
||||
data := packet
|
||||
var err error
|
||||
|
@ -298,8 +272,9 @@ func (slf *Conn) writeLoop(wait *sync.WaitGroup) {
|
|||
}
|
||||
}
|
||||
callback := data.callback
|
||||
slf.closeL.Lock()
|
||||
slf.packetPool.Release(data)
|
||||
|
||||
slf.closeL.Unlock()
|
||||
if callback != nil {
|
||||
callback(err)
|
||||
}
|
||||
|
@ -308,3 +283,31 @@ func (slf *Conn) writeLoop(wait *sync.WaitGroup) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close 关闭连接
|
||||
func (slf *Conn) Close(err ...error) {
|
||||
slf.close.Do(func() {
|
||||
slf.closeL.Lock()
|
||||
defer slf.closeL.Unlock()
|
||||
slf.closed = true
|
||||
if slf.ws != nil {
|
||||
_ = slf.ws.Close()
|
||||
} else if slf.gn != nil {
|
||||
_ = slf.gn.Close()
|
||||
} else if slf.kcp != nil {
|
||||
_ = slf.kcp.Close()
|
||||
}
|
||||
if slf.packetPool != nil {
|
||||
slf.packetPool.Close()
|
||||
}
|
||||
slf.packetPool = nil
|
||||
if slf.packets != nil {
|
||||
close(slf.packets)
|
||||
}
|
||||
if len(err) > 0 {
|
||||
slf.server.OnConnectionClosedEvent(slf, err[0])
|
||||
return
|
||||
}
|
||||
slf.server.OnConnectionClosedEvent(slf, nil)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"github.com/kercylan98/minotaur/server"
|
||||
"github.com/kercylan98/minotaur/server/client"
|
||||
"github.com/kercylan98/minotaur/utils/log"
|
||||
"go.uber.org/atomic"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
@ -38,7 +39,7 @@ type Endpoint struct {
|
|||
client []*client.Client // 端点客户端
|
||||
name string // 端点名称
|
||||
address string // 端点地址
|
||||
state float64 // 端点健康值(0为不可用,越高越优)
|
||||
state atomic.Float64 // 端点健康值(0为不可用,越高越优)
|
||||
evaluator func(costUnixNano float64) float64 // 端点健康值评估函数
|
||||
connections *haxmap.Map[string, *server.Conn] // 被该端点转发的连接列表
|
||||
rci time.Duration // 端点重连间隔
|
||||
|
@ -50,13 +51,13 @@ func (slf *Endpoint) start(gateway *Gateway, cli *client.Client) {
|
|||
for {
|
||||
cur := time.Now().UnixNano()
|
||||
if err := cli.Run(); err == nil {
|
||||
slf.state = slf.evaluator(float64(time.Now().UnixNano() - cur))
|
||||
slf.state.Swap(slf.evaluator(float64(time.Now().UnixNano() - cur)))
|
||||
break
|
||||
}
|
||||
if slf.rci > 0 {
|
||||
time.Sleep(slf.rci)
|
||||
} else {
|
||||
slf.state = 0
|
||||
slf.state.Swap(0)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
@ -83,7 +84,7 @@ func (slf *Endpoint) connect(gateway *Gateway) {
|
|||
log.Error("Endpoint", log.String("Action", "ReceivePacket"), log.String("Name", slf.name), log.String("Addr", slf.address), log.Err(err))
|
||||
return
|
||||
}
|
||||
slf.state = slf.evaluator(float64(time.Now().UnixNano() - sendTime))
|
||||
slf.state.Swap(slf.evaluator(float64(time.Now().UnixNano() - sendTime)))
|
||||
c, ok := slf.connections.Get(addr)
|
||||
if !ok {
|
||||
log.Error("Endpoint", log.String("Action", "ReceivePacket"), log.String("Name", slf.name), log.String("Addr", slf.address), log.String("ConnAddr", addr), log.Err(ErrConnectionNotFount))
|
||||
|
@ -111,7 +112,7 @@ func (slf *Endpoint) GetAddress() string {
|
|||
|
||||
// GetState 获取端点健康值
|
||||
func (slf *Endpoint) GetState() float64 {
|
||||
return slf.state
|
||||
return slf.state.Load()
|
||||
}
|
||||
|
||||
// Forward 转发数据包到该端点
|
||||
|
|
|
@ -127,7 +127,7 @@ func (slf *Gateway) GetEndpoint(name string) (*Endpoint, error) {
|
|||
|
||||
var available = make([]*Endpoint, 0, len(endpoints))
|
||||
for _, e := range endpoints {
|
||||
if e.state > 0 {
|
||||
if e.GetState() > 0 {
|
||||
available = append(available, e)
|
||||
}
|
||||
}
|
||||
|
@ -150,7 +150,7 @@ func (slf *Gateway) GetConnEndpoint(name string, conn *server.Conn) (*Endpoint,
|
|||
slf.cceLock.RLock()
|
||||
endpoint, exist := slf.cce[conn.GetID()]
|
||||
slf.cceLock.RUnlock()
|
||||
if exist && endpoint.state > 0 {
|
||||
if exist && endpoint.GetState() > 0 {
|
||||
return endpoint, nil
|
||||
}
|
||||
return slf.GetEndpoint(name)
|
||||
|
@ -158,7 +158,7 @@ func (slf *Gateway) GetConnEndpoint(name string, conn *server.Conn) (*Endpoint,
|
|||
|
||||
// SwitchEndpoint 将端点端点的所有连接切换到另一个端点
|
||||
func (slf *Gateway) SwitchEndpoint(source, dest *Endpoint) {
|
||||
if source.name == dest.name && source.address == dest.address || source.state <= 0 || dest.state <= 0 {
|
||||
if source.name == dest.name && source.address == dest.address || source.GetState() <= 0 || dest.GetState() <= 0 {
|
||||
return
|
||||
}
|
||||
slf.cceLock.Lock()
|
||||
|
|
Loading…
Reference in New Issue