diff --git a/server/conn.go b/server/conn.go index 72c863f..89a5ae5 100644 --- a/server/conn.go +++ b/server/conn.go @@ -3,6 +3,7 @@ package server import ( "github.com/gorilla/websocket" "github.com/kercylan98/minotaur/utils/concurrent" + "github.com/kercylan98/minotaur/utils/super" "github.com/panjf2000/gnet" "github.com/xtaci/kcp-go/v5" "net" @@ -65,6 +66,26 @@ func newWebsocketConn(server *Server, ws *websocket.Conn, ip string) *Conn { return c } +// newGatewayConn 创建一个处理网关消息的连接 +func newGatewayConn(conn *Conn, connId string) *Conn { + c := &Conn{ + server: conn.server, + data: map[any]any{}, + } + c.gw = func(packet Packet) { + var gp = GP{ + C: connId, + WT: packet.WebsocketType, + D: packet.Data, + T: time.Now().UnixNano(), + } + pd := super.MarshalJSON(&gp) + packet.Data = append(pd, 0xff) + conn.Write(packet) + } + return c +} + // NewEmptyConn 创建一个适用于测试的空连接 func NewEmptyConn(server *Server) *Conn { c := &Conn{ @@ -88,6 +109,7 @@ type Conn struct { ws *websocket.Conn gn gnet.Conn kcp *kcp.UDPSession + gw func(packet Packet) data map[any]any mutex sync.Mutex packetPool *concurrent.Pool[*connPacket] @@ -96,7 +118,7 @@ type Conn struct { // IsEmpty 是否是空连接 func (slf *Conn) IsEmpty() bool { - return slf.ws == nil && slf.gn == nil && slf.kcp == nil + return slf.ws == nil && slf.gn == nil && slf.kcp == nil && slf.gw == nil } // Reuse 重用连接 @@ -179,6 +201,10 @@ func (slf *Conn) IsWebsocket() bool { // Write 向连接中写入数据 // - messageType: websocket模式中指定消息类型 func (slf *Conn) Write(packet Packet) { + if slf.gw != nil { + slf.gw(packet) + return + } packet = slf.server.OnConnectionWritePacketBeforeEvent(slf, packet) if slf.packetPool == nil { return @@ -194,6 +220,10 @@ func (slf *Conn) Write(packet Packet) { // WriteWithCallback 与 Write 相同,但是会在写入完成后调用 callback // - 当 callback 为 nil 时,与 Write 相同 func (slf *Conn) WriteWithCallback(packet Packet, callback func(err error)) { + if slf.gw != nil { + slf.gw(packet) + return + } packet = slf.server.OnConnectionWritePacketBeforeEvent(slf, packet) if slf.packetPool == nil { return diff --git a/server/gateway/endpoint.go b/server/gateway/endpoint.go index 86fb05b..69edfad 100644 --- a/server/gateway/endpoint.go +++ b/server/gateway/endpoint.go @@ -3,16 +3,25 @@ package gateway import ( "github.com/kercylan98/minotaur/server" "github.com/kercylan98/minotaur/server/client" + "github.com/kercylan98/minotaur/utils/super" "time" ) // NewEndpoint 创建网关端点 -func NewEndpoint(name, address string) *Endpoint { +func NewEndpoint(name, address string, options ...EndpointOption) *Endpoint { endpoint := &Endpoint{ client: client.NewWebsocket(address), name: name, address: address, } + for _, option := range options { + option(endpoint) + } + if endpoint.evaluator == nil { + endpoint.evaluator = func(costUnixNano float64) float64 { + return 1 / (1 + 1.5*time.Duration(costUnixNano).Seconds()) + } + } endpoint.client.RegConnectionClosedEvent(endpoint.onConnectionClosed) endpoint.client.RegConnectionReceivePacketEvent(endpoint.onConnectionReceivePacket) return endpoint @@ -20,11 +29,12 @@ func NewEndpoint(name, address string) *Endpoint { // Endpoint 网关端点 type Endpoint struct { - client *client.Websocket // 端点客户端 - name string // 端点名称 - address string // 端点地址 - state float64 // 端点健康值(0为不可用,越高越优) - offline bool // 离线 + client *client.Websocket // 端点客户端 + name string // 端点名称 + address string // 端点地址 + state float64 // 端点健康值(0为不可用,越高越优) + offline bool // 离线 + evaluator func(costUnixNano float64) float64 // 端点健康值评估函数 } // Offline 离线 @@ -35,9 +45,9 @@ func (slf *Endpoint) Offline() { // Connect 连接端点 func (slf *Endpoint) Connect() { for { - var now = time.Now() + cur := time.Now().UnixNano() if err := slf.client.Run(); err == nil { - slf.state = 1 - (time.Since(now).Seconds() / 10) + slf.state = slf.evaluator(float64(time.Now().UnixNano() - cur)) break } time.Sleep(100 * time.Millisecond) @@ -56,9 +66,13 @@ func (slf *Endpoint) onConnectionClosed(conn *client.Websocket, err any) { } } -// onConnectionReceivePacket 解说到来自端点的数据包事件 +// onConnectionReceivePacket 接收到来自端点的数据包事件 func (slf *Endpoint) onConnectionReceivePacket(conn *client.Websocket, packet server.Packet) { - p := UnpackGatewayPacket(packet) - packet.Data = p.Data - conn.GetData(p.ConnID).(*server.Conn).Write(packet) + var gp server.GP + if err := super.UnmarshalJSON(packet.Data[:len(packet.Data)-1], &gp); err != nil { + panic(err) + } + cur := time.Now().UnixNano() + slf.state = slf.evaluator(float64(cur - gp.T)) + conn.GetData(gp.C).(*server.Conn).Write(server.NewWSPacket(gp.WT, gp.D)) } diff --git a/server/gateway/endpoint_options.go b/server/gateway/endpoint_options.go new file mode 100644 index 0000000..d9c2f41 --- /dev/null +++ b/server/gateway/endpoint_options.go @@ -0,0 +1,11 @@ +package gateway + +// EndpointOption 网关端点选项 +type EndpointOption func(endpoint *Endpoint) + +// WithEndpointStateEvaluator 设置端点健康值评估函数 +func WithEndpointStateEvaluator(evaluator func(costUnixNano float64) float64) EndpointOption { + return func(endpoint *Endpoint) { + endpoint.evaluator = evaluator + } +} diff --git a/server/gateway/gateway.go b/server/gateway/gateway.go index e50157d..21a7d58 100644 --- a/server/gateway/gateway.go +++ b/server/gateway/gateway.go @@ -48,27 +48,12 @@ func (slf *Gateway) onConnectionOpened(srv *server.Server, conn *server.Conn) { // onConnectionReceivePacket 连接接收数据包事件 func (slf *Gateway) onConnectionReceivePacket(srv *server.Server, conn *server.Conn, packet server.Packet) { - conn.GetData("endpoint").(*Endpoint).Write(PackGatewayPacket(conn.GetID(), packet.WebsocketType, packet.Data)) -} - -// PackGatewayPacket 打包网关数据包 -func PackGatewayPacket(connID string, websocketType int, data []byte) server.Packet { - var gatewayPacket = Packet{ - ConnID: connID, - WebsocketType: websocketType, - Data: data, - } - return server.Packet{ - WebsocketType: websocketType, - Data: super.MarshalJSON(&gatewayPacket), + var gp = server.GP{ + C: conn.GetID(), + WT: packet.WebsocketType, + D: packet.Data, } -} - -// UnpackGatewayPacket 解包网关数据包 -func UnpackGatewayPacket(packet server.Packet) Packet { - var gatewayPacket Packet - if err := super.UnmarshalJSON(packet.Data, &gatewayPacket); err != nil { - panic(err) - } - return gatewayPacket + pd := super.MarshalJSON(&gp) + packet.Data = append(pd, 0xff) + conn.GetData("endpoint").(*Endpoint).Write(packet) } diff --git a/server/gateway/gateway_test.go b/server/gateway/gateway_test.go index 3d78e1d..08b9951 100644 --- a/server/gateway/gateway_test.go +++ b/server/gateway/gateway_test.go @@ -10,8 +10,7 @@ import ( func TestGateway_RunEndpointServer(t *testing.T) { srv := server.New(server.NetworkWebsocket) srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet server.Packet) { - p := gateway2.UnpackGatewayPacket(packet) - fmt.Println("endpoint receive packet", string(p.Data)) + fmt.Println("endpoint receive packet", string(packet.Data)) conn.Write(packet) }) if err := srv.Run(":8889"); err != nil { diff --git a/server/gp.go b/server/gp.go new file mode 100644 index 0000000..e84defd --- /dev/null +++ b/server/gp.go @@ -0,0 +1,8 @@ +package server + +type GP struct { + C string // 连接 ID + WT int // WebSocket 类型 + D []byte // 数据 + T int64 // 时间戳 +} diff --git a/server/server.go b/server/server.go index 61f74eb..3305916 100644 --- a/server/server.go +++ b/server/server.go @@ -631,8 +631,19 @@ func (slf *Server) dispatchMessage(msg *Message) { var conn = attrs[0].(*Conn) var packet = attrs[1].([]byte) var wst = int(packet[len(packet)-1]) + var ct = packet[len(packet)-2] + if ct == 0xff { + var gp GP + if err := super.UnmarshalJSON(packet[:len(packet)-2], &gp); err != nil { + panic(err) + } + packet = gp.D + conn = newGatewayConn(conn, gp.C) + } else { + packet = packet[:len(packet)-1] + } if !slf.OnConnectionPacketPreprocessEvent(conn, packet, func(newPacket []byte) { packet = newPacket }) { - slf.OnConnectionReceivePacketEvent(conn, Packet{Data: packet[:len(packet)-1], WebsocketType: wst}) + slf.OnConnectionReceivePacketEvent(conn, Packet{Data: packet, WebsocketType: wst}) } case MessageTypeError: err, action := attrs[0].(error), attrs[1].(MessageErrorAction)