feat: gateway 新增 WithEndpointConnectionPoolSize 支持配置与端点建立连接的数量

This commit is contained in:
kercylan98 2023-09-01 14:30:57 +08:00
parent 2ed52fc814
commit 3ca6ed00ec
3 changed files with 85 additions and 48 deletions

View File

@ -5,23 +5,25 @@ import (
"github.com/kercylan98/minotaur/server" "github.com/kercylan98/minotaur/server"
"github.com/kercylan98/minotaur/server/client" "github.com/kercylan98/minotaur/server/client"
"github.com/kercylan98/minotaur/utils/log" "github.com/kercylan98/minotaur/utils/log"
"sync"
"time" "time"
) )
var DefaultEndpointReconnectInterval = time.Second
// NewEndpoint 创建网关端点 // NewEndpoint 创建网关端点
func NewEndpoint(name string, cli *client.Client, options ...EndpointOption) *Endpoint { func NewEndpoint(name string, cli *client.Client, options ...EndpointOption) *Endpoint {
endpoint := &Endpoint{ endpoint := &Endpoint{
client: cli,
name: name, name: name,
address: cli.GetServerAddr(), address: cli.GetServerAddr(),
connections: haxmap.New[string, *server.Conn](), connections: haxmap.New[string, *server.Conn](),
rci: DefaultEndpointReconnectInterval, rci: DefaultEndpointReconnectInterval,
cps: DefaultEndpointConnectionPoolSize,
} }
for _, option := range options { for _, option := range options {
option(endpoint) option(endpoint)
} }
for i := 0; i < endpoint.cps; i++ {
endpoint.client = append(endpoint.client, client.CloneClient(cli))
}
if endpoint.evaluator == nil { if endpoint.evaluator == nil {
endpoint.evaluator = func(costUnixNano float64) float64 { endpoint.evaluator = func(costUnixNano float64) float64 {
return 1 / (1 + 1.5*time.Duration(costUnixNano).Seconds()) return 1 / (1 + 1.5*time.Duration(costUnixNano).Seconds())
@ -33,55 +35,21 @@ func NewEndpoint(name string, cli *client.Client, options ...EndpointOption) *En
// Endpoint 网关端点 // Endpoint 网关端点
type Endpoint struct { type Endpoint struct {
gateway *Gateway gateway *Gateway
client *client.Client // 端点客户端 client []*client.Client // 端点客户端
name string // 端点名称 name string // 端点名称
address string // 端点地址 address string // 端点地址
state float64 // 端点健康值0为不可用越高越优 state float64 // 端点健康值0为不可用越高越优
evaluator func(costUnixNano float64) float64 // 端点健康值评估函数 evaluator func(costUnixNano float64) float64 // 端点健康值评估函数
connections *haxmap.Map[string, *server.Conn] // 被该端点转发的连接列表 connections *haxmap.Map[string, *server.Conn] // 被该端点转发的连接列表
rci time.Duration // 端点重连间隔 rci time.Duration // 端点重连间隔
cps int // 端点连接池大小
} }
// connect 连接端点 // start 开始与端点建立连接
func (slf *Endpoint) connect(gateway *Gateway) { func (slf *Endpoint) start(gateway *Gateway, cli *client.Client) {
slf.gateway = gateway
slf.client.RegConnectionOpenedEvent(func(conn *client.Client) {
slf.gateway.OnEndpointConnectOpenedEvent(slf.gateway, slf)
})
slf.client.RegConnectionClosedEvent(func(conn *client.Client, err any) {
slf.gateway.OnEndpointConnectClosedEvent(slf.gateway, slf)
for {
cur := time.Now().UnixNano()
if err := slf.client.Run(); err == nil {
slf.state = slf.evaluator(float64(time.Now().UnixNano() - cur))
break
}
if slf.rci > 0 {
time.Sleep(slf.rci)
} else {
slf.state = 0
break
}
}
})
slf.client.RegConnectionReceivePacketEvent(func(conn *client.Client, wst int, packet []byte) {
addr, sendTime, packet, err := UnmarshalGatewayInPacket(packet)
if err != nil {
log.Error("Endpoint", log.String("Action", "ReceivePacket"), log.String("Name", slf.name), log.String("Addr", slf.address), log.Err(err))
return
}
slf.state = slf.evaluator(float64(time.Now().UnixNano() - sendTime))
c, ok := slf.connections.Get(addr)
if !ok {
log.Error("Endpoint", log.String("Action", "ReceivePacket"), log.String("Name", slf.name), log.String("Addr", slf.address), log.String("ConnAddr", addr), log.Err(ErrConnectionNotFount))
return
}
c.SetWST(wst)
slf.gateway.OnEndpointConnectReceivePacketEvent(slf.gateway, slf, c, packet)
})
for { for {
cur := time.Now().UnixNano() cur := time.Now().UnixNano()
if err := slf.client.Run(); err == nil { if err := cli.Run(); err == nil {
slf.state = slf.evaluator(float64(time.Now().UnixNano() - cur)) slf.state = slf.evaluator(float64(time.Now().UnixNano() - cur))
break break
} }
@ -94,6 +62,43 @@ func (slf *Endpoint) connect(gateway *Gateway) {
} }
} }
// connect 连接端点
func (slf *Endpoint) connect(gateway *Gateway) {
slf.gateway = gateway
var least sync.WaitGroup
var leastOnce sync.Once
least.Add(1)
for _, cli := range slf.client {
go func(cli *client.Client) {
cli.RegConnectionOpenedEvent(func(conn *client.Client) {
slf.gateway.OnEndpointConnectOpenedEvent(slf.gateway, slf)
})
cli.RegConnectionClosedEvent(func(conn *client.Client, err any) {
slf.gateway.OnEndpointConnectClosedEvent(slf.gateway, slf)
slf.start(gateway, cli)
})
cli.RegConnectionReceivePacketEvent(func(conn *client.Client, wst int, packet []byte) {
addr, sendTime, packet, err := UnmarshalGatewayInPacket(packet)
if err != nil {
log.Error("Endpoint", log.String("Action", "ReceivePacket"), log.String("Name", slf.name), log.String("Addr", slf.address), log.Err(err))
return
}
slf.state = slf.evaluator(float64(time.Now().UnixNano() - sendTime))
c, ok := slf.connections.Get(addr)
if !ok {
log.Error("Endpoint", log.String("Action", "ReceivePacket"), log.String("Name", slf.name), log.String("Addr", slf.address), log.String("ConnAddr", addr), log.Err(ErrConnectionNotFount))
return
}
c.SetWST(wst)
slf.gateway.OnEndpointConnectReceivePacketEvent(slf.gateway, slf, c, packet)
})
slf.start(gateway, cli)
leastOnce.Do(least.Done)
}(cli)
}
least.Wait()
}
// GetName 获取端点名称 // GetName 获取端点名称
func (slf *Endpoint) GetName() string { func (slf *Endpoint) GetName() string {
return slf.name return slf.name
@ -131,9 +136,20 @@ func (slf *Endpoint) Forward(conn *server.Conn, packet []byte, callback ...func(
slf.connections.Set(conn.GetID(), conn) slf.connections.Set(conn.GetID(), conn)
} }
} }
var superior *client.Client
var superiorCount = -1
for _, cli := range slf.client {
count := cli.GetMessageAccumulationTotal()
if superiorCount < 0 || superiorCount > count {
superior = cli
superiorCount = count
}
}
if conn.IsWebsocket() { if conn.IsWebsocket() {
slf.client.WriteWS(conn.GetWST(), packet, cb) superior.WriteWS(conn.GetWST(), packet, cb)
} else { } else {
slf.client.Write(packet, cb) superior.Write(packet, cb)
} }
} }

View File

@ -1,6 +1,13 @@
package gateway package gateway
import "time" import (
"time"
)
const (
DefaultEndpointReconnectInterval = time.Second
DefaultEndpointConnectionPoolSize = 1
)
// EndpointOption 网关端点选项 // EndpointOption 网关端点选项
type EndpointOption func(endpoint *Endpoint) type EndpointOption func(endpoint *Endpoint)
@ -12,10 +19,20 @@ func WithEndpointStateEvaluator(evaluator func(costUnixNano float64) float64) En
} }
} }
// WithReconnectInterval 设置端点重连间隔 // WithEndpointConnectionPoolSize 设置端点连接池大小
// - 默认为 DefaultEndpointConnectionPoolSize
// - 端点连接池大小决定了网关服务器与端点服务器建立的连接数,如果 <= 0 则会使用默认值
// - 在网关服务器中,多个客户端在发送消息到端点服务器时,会共用一个连接,适当的增大连接池大小可以提高网关服务器的承载能力
func WithEndpointConnectionPoolSize(size int) EndpointOption {
return func(endpoint *Endpoint) {
endpoint.cps = size
}
}
// WithEndpointReconnectInterval 设置端点重连间隔
// - 默认为 DefaultEndpointReconnectInterval // - 默认为 DefaultEndpointReconnectInterval
// - 端点在连接失败后会在该间隔后重连,如果 <= 0 则不会重连 // - 端点在连接失败后会在该间隔后重连,如果 <= 0 则不会重连
func WithReconnectInterval(interval time.Duration) EndpointOption { func WithEndpointReconnectInterval(interval time.Duration) EndpointOption {
return func(endpoint *Endpoint) { return func(endpoint *Endpoint) {
endpoint.rci = interval endpoint.rci = interval
} }

View File

@ -4,6 +4,7 @@ import (
"github.com/kercylan98/minotaur/server" "github.com/kercylan98/minotaur/server"
"github.com/kercylan98/minotaur/server/client" "github.com/kercylan98/minotaur/server/client"
"github.com/kercylan98/minotaur/server/gateway" "github.com/kercylan98/minotaur/server/gateway"
"github.com/kercylan98/minotaur/utils/log"
"testing" "testing"
"time" "time"
) )
@ -13,7 +14,7 @@ type Scanner struct {
func (slf *Scanner) GetEndpoints() ([]*gateway.Endpoint, error) { func (slf *Scanner) GetEndpoints() ([]*gateway.Endpoint, error) {
return []*gateway.Endpoint{ return []*gateway.Endpoint{
gateway.NewEndpoint("test", client.NewWebsocket("ws://127.0.0.1:8889")), gateway.NewEndpoint("test", client.NewWebsocket("ws://127.0.0.1:8889"), gateway.WithEndpointConnectionPoolSize(10)),
}, nil }, nil
} }
@ -46,6 +47,9 @@ func TestGateway_RunEndpointServer(t *testing.T) {
srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet []byte) { srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet []byte) {
conn.Write(packet) conn.Write(packet)
}) })
srv.RegConnectionOpenedEvent(func(srv *server.Server, conn *server.Conn) {
log.Info("connection opened", log.String("conn", conn.GetID()))
})
if err := srv.Run(":8889"); err != nil { if err := srv.Run(":8889"); err != nil {
panic(err) panic(err)
} }