feat: gateway 新增 WithEndpointConnectionPoolSize 支持配置与端点建立连接的数量

This commit is contained in:
kercylan98 2023-09-01 14:30:57 +08:00
parent 2ed52fc814
commit 3ca6ed00ec
3 changed files with 85 additions and 48 deletions

View File

@ -5,23 +5,25 @@ import (
"github.com/kercylan98/minotaur/server"
"github.com/kercylan98/minotaur/server/client"
"github.com/kercylan98/minotaur/utils/log"
"sync"
"time"
)
var DefaultEndpointReconnectInterval = time.Second
// NewEndpoint 创建网关端点
func NewEndpoint(name string, cli *client.Client, options ...EndpointOption) *Endpoint {
endpoint := &Endpoint{
client: cli,
name: name,
address: cli.GetServerAddr(),
connections: haxmap.New[string, *server.Conn](),
rci: DefaultEndpointReconnectInterval,
cps: DefaultEndpointConnectionPoolSize,
}
for _, option := range options {
option(endpoint)
}
for i := 0; i < endpoint.cps; i++ {
endpoint.client = append(endpoint.client, client.CloneClient(cli))
}
if endpoint.evaluator == nil {
endpoint.evaluator = func(costUnixNano float64) float64 {
return 1 / (1 + 1.5*time.Duration(costUnixNano).Seconds())
@ -33,26 +35,21 @@ func NewEndpoint(name string, cli *client.Client, options ...EndpointOption) *En
// Endpoint 网关端点
type Endpoint struct {
gateway *Gateway
client *client.Client // 端点客户端
client []*client.Client // 端点客户端
name string // 端点名称
address string // 端点地址
state float64 // 端点健康值0为不可用越高越优
evaluator func(costUnixNano float64) float64 // 端点健康值评估函数
connections *haxmap.Map[string, *server.Conn] // 被该端点转发的连接列表
rci time.Duration // 端点重连间隔
cps int // 端点连接池大小
}
// connect 连接端点
func (slf *Endpoint) connect(gateway *Gateway) {
slf.gateway = gateway
slf.client.RegConnectionOpenedEvent(func(conn *client.Client) {
slf.gateway.OnEndpointConnectOpenedEvent(slf.gateway, slf)
})
slf.client.RegConnectionClosedEvent(func(conn *client.Client, err any) {
slf.gateway.OnEndpointConnectClosedEvent(slf.gateway, slf)
// start 开始与端点建立连接
func (slf *Endpoint) start(gateway *Gateway, cli *client.Client) {
for {
cur := time.Now().UnixNano()
if err := slf.client.Run(); err == nil {
if err := cli.Run(); err == nil {
slf.state = slf.evaluator(float64(time.Now().UnixNano() - cur))
break
}
@ -63,8 +60,24 @@ func (slf *Endpoint) connect(gateway *Gateway) {
break
}
}
}
// connect 连接端点
func (slf *Endpoint) connect(gateway *Gateway) {
slf.gateway = gateway
var least sync.WaitGroup
var leastOnce sync.Once
least.Add(1)
for _, cli := range slf.client {
go func(cli *client.Client) {
cli.RegConnectionOpenedEvent(func(conn *client.Client) {
slf.gateway.OnEndpointConnectOpenedEvent(slf.gateway, slf)
})
slf.client.RegConnectionReceivePacketEvent(func(conn *client.Client, wst int, packet []byte) {
cli.RegConnectionClosedEvent(func(conn *client.Client, err any) {
slf.gateway.OnEndpointConnectClosedEvent(slf.gateway, slf)
slf.start(gateway, cli)
})
cli.RegConnectionReceivePacketEvent(func(conn *client.Client, wst int, packet []byte) {
addr, sendTime, packet, err := UnmarshalGatewayInPacket(packet)
if err != nil {
log.Error("Endpoint", log.String("Action", "ReceivePacket"), log.String("Name", slf.name), log.String("Addr", slf.address), log.Err(err))
@ -79,19 +92,11 @@ func (slf *Endpoint) connect(gateway *Gateway) {
c.SetWST(wst)
slf.gateway.OnEndpointConnectReceivePacketEvent(slf.gateway, slf, c, packet)
})
for {
cur := time.Now().UnixNano()
if err := slf.client.Run(); err == nil {
slf.state = slf.evaluator(float64(time.Now().UnixNano() - cur))
break
}
if slf.rci > 0 {
time.Sleep(slf.rci)
} else {
slf.state = 0
break
}
slf.start(gateway, cli)
leastOnce.Do(least.Done)
}(cli)
}
least.Wait()
}
// GetName 获取端点名称
@ -131,9 +136,20 @@ func (slf *Endpoint) Forward(conn *server.Conn, packet []byte, callback ...func(
slf.connections.Set(conn.GetID(), conn)
}
}
var superior *client.Client
var superiorCount = -1
for _, cli := range slf.client {
count := cli.GetMessageAccumulationTotal()
if superiorCount < 0 || superiorCount > count {
superior = cli
superiorCount = count
}
}
if conn.IsWebsocket() {
slf.client.WriteWS(conn.GetWST(), packet, cb)
superior.WriteWS(conn.GetWST(), packet, cb)
} else {
slf.client.Write(packet, cb)
superior.Write(packet, cb)
}
}

View File

@ -1,6 +1,13 @@
package gateway
import "time"
import (
"time"
)
const (
DefaultEndpointReconnectInterval = time.Second
DefaultEndpointConnectionPoolSize = 1
)
// EndpointOption 网关端点选项
type EndpointOption func(endpoint *Endpoint)
@ -12,10 +19,20 @@ func WithEndpointStateEvaluator(evaluator func(costUnixNano float64) float64) En
}
}
// WithReconnectInterval 设置端点重连间隔
// WithEndpointConnectionPoolSize 设置端点连接池大小
// - 默认为 DefaultEndpointConnectionPoolSize
// - 端点连接池大小决定了网关服务器与端点服务器建立的连接数,如果 <= 0 则会使用默认值
// - 在网关服务器中,多个客户端在发送消息到端点服务器时,会共用一个连接,适当的增大连接池大小可以提高网关服务器的承载能力
func WithEndpointConnectionPoolSize(size int) EndpointOption {
return func(endpoint *Endpoint) {
endpoint.cps = size
}
}
// WithEndpointReconnectInterval 设置端点重连间隔
// - 默认为 DefaultEndpointReconnectInterval
// - 端点在连接失败后会在该间隔后重连,如果 <= 0 则不会重连
func WithReconnectInterval(interval time.Duration) EndpointOption {
func WithEndpointReconnectInterval(interval time.Duration) EndpointOption {
return func(endpoint *Endpoint) {
endpoint.rci = interval
}

View File

@ -4,6 +4,7 @@ import (
"github.com/kercylan98/minotaur/server"
"github.com/kercylan98/minotaur/server/client"
"github.com/kercylan98/minotaur/server/gateway"
"github.com/kercylan98/minotaur/utils/log"
"testing"
"time"
)
@ -13,7 +14,7 @@ type Scanner struct {
func (slf *Scanner) GetEndpoints() ([]*gateway.Endpoint, error) {
return []*gateway.Endpoint{
gateway.NewEndpoint("test", client.NewWebsocket("ws://127.0.0.1:8889")),
gateway.NewEndpoint("test", client.NewWebsocket("ws://127.0.0.1:8889"), gateway.WithEndpointConnectionPoolSize(10)),
}, nil
}
@ -46,6 +47,9 @@ func TestGateway_RunEndpointServer(t *testing.T) {
srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet []byte) {
conn.Write(packet)
})
srv.RegConnectionOpenedEvent(func(srv *server.Server, conn *server.Conn) {
log.Info("connection opened", log.String("conn", conn.GetID()))
})
if err := srv.Run(":8889"); err != nil {
panic(err)
}