From 3ca6ed00ec91c34a4a61a61dcfd5731da8faba66 Mon Sep 17 00:00:00 2001 From: kercylan98 Date: Fri, 1 Sep 2023 14:30:57 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20gateway=20=E6=96=B0=E5=A2=9E=20WithEndp?= =?UTF-8?q?ointConnectionPoolSize=20=E6=94=AF=E6=8C=81=E9=85=8D=E7=BD=AE?= =?UTF-8?q?=E4=B8=8E=E7=AB=AF=E7=82=B9=E5=BB=BA=E7=AB=8B=E8=BF=9E=E6=8E=A5?= =?UTF-8?q?=E7=9A=84=E6=95=B0=E9=87=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/gateway/endpoint.go | 104 +++++++++++++++++------------ server/gateway/endpoint_options.go | 23 ++++++- server/gateway/gateway_test.go | 6 +- 3 files changed, 85 insertions(+), 48 deletions(-) diff --git a/server/gateway/endpoint.go b/server/gateway/endpoint.go index 5a838d8..30fde8a 100644 --- a/server/gateway/endpoint.go +++ b/server/gateway/endpoint.go @@ -5,23 +5,25 @@ import ( "github.com/kercylan98/minotaur/server" "github.com/kercylan98/minotaur/server/client" "github.com/kercylan98/minotaur/utils/log" + "sync" "time" ) -var DefaultEndpointReconnectInterval = time.Second - // NewEndpoint 创建网关端点 func NewEndpoint(name string, cli *client.Client, options ...EndpointOption) *Endpoint { endpoint := &Endpoint{ - client: cli, name: name, address: cli.GetServerAddr(), connections: haxmap.New[string, *server.Conn](), rci: DefaultEndpointReconnectInterval, + cps: DefaultEndpointConnectionPoolSize, } for _, option := range options { option(endpoint) } + for i := 0; i < endpoint.cps; i++ { + endpoint.client = append(endpoint.client, client.CloneClient(cli)) + } if endpoint.evaluator == nil { endpoint.evaluator = func(costUnixNano float64) float64 { return 1 / (1 + 1.5*time.Duration(costUnixNano).Seconds()) @@ -33,55 +35,21 @@ func NewEndpoint(name string, cli *client.Client, options ...EndpointOption) *En // Endpoint 网关端点 type Endpoint struct { gateway *Gateway - client *client.Client // 端点客户端 + client []*client.Client // 端点客户端 name string // 端点名称 address string // 端点地址 state float64 // 端点健康值(0为不可用,越高越优) evaluator func(costUnixNano float64) float64 // 端点健康值评估函数 connections *haxmap.Map[string, *server.Conn] // 被该端点转发的连接列表 rci time.Duration // 端点重连间隔 + cps int // 端点连接池大小 } -// connect 连接端点 -func (slf *Endpoint) connect(gateway *Gateway) { - slf.gateway = gateway - slf.client.RegConnectionOpenedEvent(func(conn *client.Client) { - slf.gateway.OnEndpointConnectOpenedEvent(slf.gateway, slf) - }) - slf.client.RegConnectionClosedEvent(func(conn *client.Client, err any) { - slf.gateway.OnEndpointConnectClosedEvent(slf.gateway, slf) - for { - cur := time.Now().UnixNano() - if err := slf.client.Run(); err == nil { - slf.state = slf.evaluator(float64(time.Now().UnixNano() - cur)) - break - } - if slf.rci > 0 { - time.Sleep(slf.rci) - } else { - slf.state = 0 - break - } - } - }) - slf.client.RegConnectionReceivePacketEvent(func(conn *client.Client, wst int, packet []byte) { - addr, sendTime, packet, err := UnmarshalGatewayInPacket(packet) - if err != nil { - log.Error("Endpoint", log.String("Action", "ReceivePacket"), log.String("Name", slf.name), log.String("Addr", slf.address), log.Err(err)) - return - } - slf.state = slf.evaluator(float64(time.Now().UnixNano() - sendTime)) - c, ok := slf.connections.Get(addr) - if !ok { - log.Error("Endpoint", log.String("Action", "ReceivePacket"), log.String("Name", slf.name), log.String("Addr", slf.address), log.String("ConnAddr", addr), log.Err(ErrConnectionNotFount)) - return - } - c.SetWST(wst) - slf.gateway.OnEndpointConnectReceivePacketEvent(slf.gateway, slf, c, packet) - }) +// start 开始与端点建立连接 +func (slf *Endpoint) start(gateway *Gateway, cli *client.Client) { for { cur := time.Now().UnixNano() - if err := slf.client.Run(); err == nil { + if err := cli.Run(); err == nil { slf.state = slf.evaluator(float64(time.Now().UnixNano() - cur)) break } @@ -94,6 +62,43 @@ func (slf *Endpoint) connect(gateway *Gateway) { } } +// connect 连接端点 +func (slf *Endpoint) connect(gateway *Gateway) { + slf.gateway = gateway + var least sync.WaitGroup + var leastOnce sync.Once + least.Add(1) + for _, cli := range slf.client { + go func(cli *client.Client) { + cli.RegConnectionOpenedEvent(func(conn *client.Client) { + slf.gateway.OnEndpointConnectOpenedEvent(slf.gateway, slf) + }) + cli.RegConnectionClosedEvent(func(conn *client.Client, err any) { + slf.gateway.OnEndpointConnectClosedEvent(slf.gateway, slf) + slf.start(gateway, cli) + }) + cli.RegConnectionReceivePacketEvent(func(conn *client.Client, wst int, packet []byte) { + addr, sendTime, packet, err := UnmarshalGatewayInPacket(packet) + if err != nil { + log.Error("Endpoint", log.String("Action", "ReceivePacket"), log.String("Name", slf.name), log.String("Addr", slf.address), log.Err(err)) + return + } + slf.state = slf.evaluator(float64(time.Now().UnixNano() - sendTime)) + c, ok := slf.connections.Get(addr) + if !ok { + log.Error("Endpoint", log.String("Action", "ReceivePacket"), log.String("Name", slf.name), log.String("Addr", slf.address), log.String("ConnAddr", addr), log.Err(ErrConnectionNotFount)) + return + } + c.SetWST(wst) + slf.gateway.OnEndpointConnectReceivePacketEvent(slf.gateway, slf, c, packet) + }) + slf.start(gateway, cli) + leastOnce.Do(least.Done) + }(cli) + } + least.Wait() +} + // GetName 获取端点名称 func (slf *Endpoint) GetName() string { return slf.name @@ -131,9 +136,20 @@ func (slf *Endpoint) Forward(conn *server.Conn, packet []byte, callback ...func( slf.connections.Set(conn.GetID(), conn) } } + + var superior *client.Client + var superiorCount = -1 + for _, cli := range slf.client { + count := cli.GetMessageAccumulationTotal() + if superiorCount < 0 || superiorCount > count { + superior = cli + superiorCount = count + } + } + if conn.IsWebsocket() { - slf.client.WriteWS(conn.GetWST(), packet, cb) + superior.WriteWS(conn.GetWST(), packet, cb) } else { - slf.client.Write(packet, cb) + superior.Write(packet, cb) } } diff --git a/server/gateway/endpoint_options.go b/server/gateway/endpoint_options.go index 0e1063c..aaa840d 100644 --- a/server/gateway/endpoint_options.go +++ b/server/gateway/endpoint_options.go @@ -1,6 +1,13 @@ package gateway -import "time" +import ( + "time" +) + +const ( + DefaultEndpointReconnectInterval = time.Second + DefaultEndpointConnectionPoolSize = 1 +) // EndpointOption 网关端点选项 type EndpointOption func(endpoint *Endpoint) @@ -12,10 +19,20 @@ func WithEndpointStateEvaluator(evaluator func(costUnixNano float64) float64) En } } -// WithReconnectInterval 设置端点重连间隔 +// WithEndpointConnectionPoolSize 设置端点连接池大小 +// - 默认为 DefaultEndpointConnectionPoolSize +// - 端点连接池大小决定了网关服务器与端点服务器建立的连接数,如果 <= 0 则会使用默认值 +// - 在网关服务器中,多个客户端在发送消息到端点服务器时,会共用一个连接,适当的增大连接池大小可以提高网关服务器的承载能力 +func WithEndpointConnectionPoolSize(size int) EndpointOption { + return func(endpoint *Endpoint) { + endpoint.cps = size + } +} + +// WithEndpointReconnectInterval 设置端点重连间隔 // - 默认为 DefaultEndpointReconnectInterval // - 端点在连接失败后会在该间隔后重连,如果 <= 0 则不会重连 -func WithReconnectInterval(interval time.Duration) EndpointOption { +func WithEndpointReconnectInterval(interval time.Duration) EndpointOption { return func(endpoint *Endpoint) { endpoint.rci = interval } diff --git a/server/gateway/gateway_test.go b/server/gateway/gateway_test.go index ba63006..1bb2f56 100644 --- a/server/gateway/gateway_test.go +++ b/server/gateway/gateway_test.go @@ -4,6 +4,7 @@ import ( "github.com/kercylan98/minotaur/server" "github.com/kercylan98/minotaur/server/client" "github.com/kercylan98/minotaur/server/gateway" + "github.com/kercylan98/minotaur/utils/log" "testing" "time" ) @@ -13,7 +14,7 @@ type Scanner struct { func (slf *Scanner) GetEndpoints() ([]*gateway.Endpoint, error) { return []*gateway.Endpoint{ - gateway.NewEndpoint("test", client.NewWebsocket("ws://127.0.0.1:8889")), + gateway.NewEndpoint("test", client.NewWebsocket("ws://127.0.0.1:8889"), gateway.WithEndpointConnectionPoolSize(10)), }, nil } @@ -46,6 +47,9 @@ func TestGateway_RunEndpointServer(t *testing.T) { srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet []byte) { conn.Write(packet) }) + srv.RegConnectionOpenedEvent(func(srv *server.Server, conn *server.Conn) { + log.Info("connection opened", log.String("conn", conn.GetID())) + }) if err := srv.Run(":8889"); err != nil { panic(err) }