diff --git a/server/client/client.go b/server/client/client.go index a1d82cf..cc42505 100644 --- a/server/client/client.go +++ b/server/client/client.go @@ -34,14 +34,14 @@ type Client struct { func (slf *Client) Run() error { var runState = make(chan error) - go func() { + go func(runState chan<- error) { defer func() { if err := recover(); err != nil { slf.Close(err.(error)) } }() slf.core.Run(runState, slf.onReceive) - }() + }(runState) err := <-runState if err != nil { slf.mutex.Lock() @@ -81,13 +81,11 @@ func (slf *Client) Close(err ...error) { } if slf.packets != nil { close(slf.packets) - slf.packets = nil } if slf.accumulate != nil { close(slf.accumulate) slf.accumulate = nil } - slf.packets = nil unlock = true slf.mutex.Unlock() if len(err) > 0 { @@ -139,6 +137,7 @@ func (slf *Client) write(wst int, packet []byte, callback ...func(err error)) { // writeLoop 写循环 func (slf *Client) writeLoop(wait *sync.WaitGroup) { + slf.mutex.Lock() slf.packets = make(chan *Packet, 1024*10) slf.packetPool = concurrent.NewPool[*Packet](10*1024, func() *Packet { @@ -161,11 +160,21 @@ func (slf *Client) writeLoop(wait *sync.WaitGroup) { err = fmt.Errorf("%v", err) } slf.Close(err) + slf.packets = nil } }() wait.Done() + slf.mutex.Unlock() + + for { + packet, ok := <-slf.packets + if !ok { + slf.mutex.Lock() + slf.packets = nil + slf.mutex.Unlock() + break + } - for packet := range slf.packets { data := packet var err = slf.core.Write(data) callback := data.callback diff --git a/server/conn.go b/server/conn.go index afa73d7..d6d947a 100644 --- a/server/conn.go +++ b/server/conn.go @@ -17,7 +17,6 @@ func newKcpConn(server *Server, session *kcp.UDPSession) *Conn { ctx: server.ctx, connection: &connection{ packets: make(chan *connPacket, 1024*10), - mutex: new(sync.Mutex), server: server, remoteAddr: session.RemoteAddr(), ip: session.RemoteAddr().String(), @@ -41,7 +40,6 @@ func newGNetConn(server *Server, conn gnet.Conn) *Conn { ctx: server.ctx, connection: &connection{ packets: make(chan *connPacket, 1024*10), - mutex: new(sync.Mutex), server: server, remoteAddr: conn.RemoteAddr(), ip: conn.RemoteAddr().String(), @@ -65,7 +63,6 @@ func newWebsocketConn(server *Server, ws *websocket.Conn, ip string) *Conn { ctx: server.ctx, connection: &connection{ packets: make(chan *connPacket, 1024*10), - mutex: new(sync.Mutex), server: server, remoteAddr: ws.RemoteAddr(), ip: ip, @@ -86,7 +83,6 @@ func newGatewayConn(conn *Conn, connId string) *Conn { //ctx: server.ctx, connection: &connection{ packets: make(chan *connPacket, 1024*10), - mutex: new(sync.Mutex), server: conn.server, data: map[any]any{}, }, @@ -103,7 +99,6 @@ func NewEmptyConn(server *Server) *Conn { ctx: server.ctx, connection: &connection{ packets: make(chan *connPacket, 1024*10), - mutex: new(sync.Mutex), server: server, remoteAddr: &net.TCPAddr{}, ip: "0.0.0.0:0", @@ -126,9 +121,9 @@ type Conn struct { // connection 长久保持的连接 type connection struct { server *Server - mutex *sync.Mutex close sync.Once closed bool + closeL sync.Mutex remoteAddr net.Addr ip string ws *websocket.Conn @@ -166,33 +161,6 @@ func (slf *Conn) IsClosed() bool { return slf.closed } -// Close 关闭连接 -func (slf *Conn) Close(err ...error) { - slf.close.Do(func() { - slf.closed = true - if slf.ws != nil { - _ = slf.ws.Close() - } else if slf.gn != nil { - _ = slf.gn.Close() - } else if slf.kcp != nil { - _ = slf.kcp.Close() - } - if slf.packetPool != nil { - slf.packetPool.Close() - } - slf.packetPool = nil - if slf.packets != nil { - close(slf.packets) - slf.packets = nil - } - if len(err) > 0 { - slf.server.OnConnectionClosedEvent(slf, err[0]) - return - } - slf.server.OnConnectionClosedEvent(slf, nil) - }) -} - // SetData 设置连接数据,该数据将在连接关闭前始终存在 func (slf *Conn) SetData(key, value any) *Conn { slf.data[key] = value @@ -243,13 +211,13 @@ func (slf *Conn) SetWST(wst int) *Conn { // Write 向连接中写入数据 // - messageType: websocket模式中指定消息类型 func (slf *Conn) Write(packet []byte, callback ...func(err error)) { - slf.mutex.Lock() - defer slf.mutex.Unlock() if slf.gw != nil { slf.gw(packet) return } packet = slf.server.OnConnectionWritePacketBeforeEvent(slf, packet) + slf.closeL.Lock() + defer slf.closeL.Unlock() if slf.packetPool == nil || slf.packets == nil { return } @@ -276,10 +244,16 @@ func (slf *Conn) writeLoop(wait *sync.WaitGroup) { defer func() { if err := recover(); err != nil { slf.Close() + slf.packets = nil } }() wait.Done() - for packet := range slf.packets { + for { + packet, ok := <-slf.packets + if !ok { + slf.packets = nil + break + } data := packet var err error @@ -298,8 +272,9 @@ func (slf *Conn) writeLoop(wait *sync.WaitGroup) { } } callback := data.callback + slf.closeL.Lock() slf.packetPool.Release(data) - + slf.closeL.Unlock() if callback != nil { callback(err) } @@ -308,3 +283,31 @@ func (slf *Conn) writeLoop(wait *sync.WaitGroup) { } } } + +// Close 关闭连接 +func (slf *Conn) Close(err ...error) { + slf.close.Do(func() { + slf.closeL.Lock() + defer slf.closeL.Unlock() + slf.closed = true + if slf.ws != nil { + _ = slf.ws.Close() + } else if slf.gn != nil { + _ = slf.gn.Close() + } else if slf.kcp != nil { + _ = slf.kcp.Close() + } + if slf.packetPool != nil { + slf.packetPool.Close() + } + slf.packetPool = nil + if slf.packets != nil { + close(slf.packets) + } + if len(err) > 0 { + slf.server.OnConnectionClosedEvent(slf, err[0]) + return + } + slf.server.OnConnectionClosedEvent(slf, nil) + }) +} diff --git a/server/gateway/endpoint.go b/server/gateway/endpoint.go index ce74ea7..c5376a1 100644 --- a/server/gateway/endpoint.go +++ b/server/gateway/endpoint.go @@ -5,6 +5,7 @@ import ( "github.com/kercylan98/minotaur/server" "github.com/kercylan98/minotaur/server/client" "github.com/kercylan98/minotaur/utils/log" + "go.uber.org/atomic" "sync" "time" ) @@ -38,7 +39,7 @@ type Endpoint struct { client []*client.Client // 端点客户端 name string // 端点名称 address string // 端点地址 - state float64 // 端点健康值(0为不可用,越高越优) + state atomic.Float64 // 端点健康值(0为不可用,越高越优) evaluator func(costUnixNano float64) float64 // 端点健康值评估函数 connections *haxmap.Map[string, *server.Conn] // 被该端点转发的连接列表 rci time.Duration // 端点重连间隔 @@ -50,13 +51,13 @@ func (slf *Endpoint) start(gateway *Gateway, cli *client.Client) { for { cur := time.Now().UnixNano() if err := cli.Run(); err == nil { - slf.state = slf.evaluator(float64(time.Now().UnixNano() - cur)) + slf.state.Swap(slf.evaluator(float64(time.Now().UnixNano() - cur))) break } if slf.rci > 0 { time.Sleep(slf.rci) } else { - slf.state = 0 + slf.state.Swap(0) break } } @@ -83,7 +84,7 @@ func (slf *Endpoint) connect(gateway *Gateway) { log.Error("Endpoint", log.String("Action", "ReceivePacket"), log.String("Name", slf.name), log.String("Addr", slf.address), log.Err(err)) return } - slf.state = slf.evaluator(float64(time.Now().UnixNano() - sendTime)) + slf.state.Swap(slf.evaluator(float64(time.Now().UnixNano() - sendTime))) c, ok := slf.connections.Get(addr) if !ok { log.Error("Endpoint", log.String("Action", "ReceivePacket"), log.String("Name", slf.name), log.String("Addr", slf.address), log.String("ConnAddr", addr), log.Err(ErrConnectionNotFount)) @@ -111,7 +112,7 @@ func (slf *Endpoint) GetAddress() string { // GetState 获取端点健康值 func (slf *Endpoint) GetState() float64 { - return slf.state + return slf.state.Load() } // Forward 转发数据包到该端点 diff --git a/server/gateway/gateway.go b/server/gateway/gateway.go index 8d3569b..977583c 100644 --- a/server/gateway/gateway.go +++ b/server/gateway/gateway.go @@ -127,7 +127,7 @@ func (slf *Gateway) GetEndpoint(name string) (*Endpoint, error) { var available = make([]*Endpoint, 0, len(endpoints)) for _, e := range endpoints { - if e.state > 0 { + if e.GetState() > 0 { available = append(available, e) } } @@ -150,7 +150,7 @@ func (slf *Gateway) GetConnEndpoint(name string, conn *server.Conn) (*Endpoint, slf.cceLock.RLock() endpoint, exist := slf.cce[conn.GetID()] slf.cceLock.RUnlock() - if exist && endpoint.state > 0 { + if exist && endpoint.GetState() > 0 { return endpoint, nil } return slf.GetEndpoint(name) @@ -158,7 +158,7 @@ func (slf *Gateway) GetConnEndpoint(name string, conn *server.Conn) (*Endpoint, // SwitchEndpoint 将端点端点的所有连接切换到另一个端点 func (slf *Gateway) SwitchEndpoint(source, dest *Endpoint) { - if source.name == dest.name && source.address == dest.address || source.state <= 0 || dest.state <= 0 { + if source.name == dest.name && source.address == dest.address || source.GetState() <= 0 || dest.GetState() <= 0 { return } slf.cceLock.Lock()