refactor: gateway 整体优化重构

This commit is contained in:
kercylan98 2023-08-24 16:40:03 +08:00
parent a3bb10012e
commit 30e7894a37
8 changed files with 323 additions and 187 deletions

View File

@ -4,16 +4,16 @@ import (
"github.com/alphadose/haxmap"
"github.com/kercylan98/minotaur/server"
"github.com/kercylan98/minotaur/server/client"
"github.com/kercylan98/minotaur/utils/log"
"time"
)
// NewEndpoint 创建网关端点
func NewEndpoint(gateway *Gateway, name string, client *client.Client, options ...EndpointOption) *Endpoint {
func NewEndpoint(name string, cli *client.Client, options ...EndpointOption) *Endpoint {
endpoint := &Endpoint{
gateway: gateway,
client: client,
client: cli,
name: name,
address: client.GetServerAddr(),
address: cli.GetServerAddr(),
connections: haxmap.New[string, *server.Conn](),
}
for _, option := range options {
@ -24,8 +24,6 @@ func NewEndpoint(gateway *Gateway, name string, client *client.Client, options .
return 1 / (1 + 1.5*time.Duration(costUnixNano).Seconds())
}
}
endpoint.client.RegConnectionClosedEvent(endpoint.onConnectionClosed)
endpoint.client.RegConnectionReceivePacketEvent(endpoint.onConnectionReceivePacket)
return endpoint
}
@ -36,34 +34,42 @@ type Endpoint struct {
name string // 端点名称
address string // 端点地址
state float64 // 端点健康值0为不可用越高越优
offline bool // 离线
evaluator func(costUnixNano float64) float64 // 端点健康值评估函数
connections *haxmap.Map[string, *server.Conn] // 连接列表
connections *haxmap.Map[string, *server.Conn] // 被该端点转发的连接列表
}
// Link 连接端点
func (slf *Endpoint) Link(conn *server.Conn) {
slf.connections.Set(conn.GetID(), conn)
// connect 连接端点
func (slf *Endpoint) connect(gateway *Gateway) {
slf.gateway = gateway
slf.client.RegConnectionOpenedEvent(func(conn *client.Client) {
slf.gateway.OnEndpointConnectOpenedEvent(slf.gateway, slf)
})
slf.client.RegConnectionClosedEvent(func(conn *client.Client, err any) {
slf.gateway.OnEndpointConnectClosedEvent(slf.gateway, slf)
for {
cur := time.Now().UnixNano()
if err := slf.client.Run(); err == nil {
slf.state = slf.evaluator(float64(time.Now().UnixNano() - cur))
break
}
// Unlink 断开连接
func (slf *Endpoint) Unlink(conn *server.Conn) {
slf.connections.Del(conn.GetID())
time.Sleep(1000 * time.Millisecond)
}
// GetLink 获取连接
func (slf *Endpoint) GetLink(id string) *server.Conn {
conn, _ := slf.connections.Get(id)
return conn
})
slf.client.RegConnectionReceivePacketEvent(func(conn *client.Client, wst int, packet []byte) {
addr, sendTime, packet, err := UnmarshalGatewayInPacket(packet)
if err != nil {
log.Error("Endpoint", log.String("Action", "ReceivePacket"), log.String("Name", slf.name), log.String("Addr", slf.address), log.Err(err))
return
}
// Offline 离线
func (slf *Endpoint) Offline() {
slf.offline = true
slf.state = slf.evaluator(float64(time.Now().UnixNano() - sendTime))
c, ok := slf.connections.Get(addr)
if !ok {
log.Error("Endpoint", log.String("Action", "ReceivePacket"), log.String("Name", slf.name), log.String("Addr", slf.address), log.String("ConnAddr", addr), log.Err(ErrConnectionNotFount))
return
}
// Connect 连接端点
func (slf *Endpoint) Connect() {
c.SetWST(wst)
slf.gateway.OnEndpointConnectReceivePacketEvent(slf.gateway, slf, c, packet)
})
for {
cur := time.Now().UnixNano()
if err := slf.client.Run(); err == nil {
@ -74,33 +80,46 @@ func (slf *Endpoint) Connect() {
}
}
// Write 写入数据
func (slf *Endpoint) Write(packet []byte, callback ...func(err error)) {
slf.client.Write(packet, callback...)
// GetName 获取端点名称
func (slf *Endpoint) GetName() string {
return slf.name
}
// WriteWS 写入 websocket 数据
func (slf *Endpoint) WriteWS(wst int, packet []byte, callback ...func(err error)) {
slf.client.WriteWS(wst, packet, callback...)
// GetAddress 获取端点地址
func (slf *Endpoint) GetAddress() string {
return slf.address
}
// onConnectionClosed 与端点连接断开事件
func (slf *Endpoint) onConnectionClosed(conn *client.Client, err any) {
if !slf.offline {
go slf.Connect()
}
// GetState 获取端点健康值
func (slf *Endpoint) GetState() float64 {
return slf.state
}
// onConnectionReceivePacket 接收到来自端点的数据包事件
func (slf *Endpoint) onConnectionReceivePacket(conn *client.Client, wst int, packet []byte) {
addr, sendTime, packet, err := UnmarshalGatewayInPacket(packet)
// Forward 转发数据包到该端点
// - 端点在处理数据包时,应区分数据包为普通直连数据包还是网关数据包。可通过 UnmarshalGatewayOutPacket 进行数据包解析,当解析失败且无其他数据包协议时,可认为该数据包为普通直连数据包。
func (slf *Endpoint) Forward(conn *server.Conn, packet []byte, callback ...func(err error)) {
var err error
packet, err = MarshalGatewayOutPacket(conn.GetID(), packet)
if err != nil {
panic(err)
if len(callback) > 0 {
callback[0](err)
}
slf.state = slf.evaluator(float64(time.Now().UnixNano() - sendTime))
cli := slf.GetLink(addr)
if cli == nil {
return
}
cli.SetWST(wst).Write(packet)
cb := func(err error) {
if len(callback) > 0 {
callback[0](err)
}
if err != nil {
slf.connections.Del(conn.GetID())
} else {
slf.connections.Set(conn.GetID(), conn)
}
}
if conn.IsWebsocket() {
slf.client.WriteWS(conn.GetWST(), packet, cb)
} else {
slf.client.Write(packet, cb)
}
}

View File

@ -1,96 +0,0 @@
package gateway
import (
"github.com/alphadose/haxmap"
"github.com/kercylan98/minotaur/utils/concurrent"
"github.com/kercylan98/minotaur/utils/random"
)
// NewEndpointManager 创建网关端点管理器
func NewEndpointManager() *EndpointManager {
em := &EndpointManager{
endpoints: concurrent.NewBalanceMap[string, []*Endpoint](),
memory: haxmap.New[string, *Endpoint](),
selector: func(endpoints []*Endpoint) *Endpoint {
return endpoints[random.Int(0, len(endpoints)-1)]
},
}
return em
}
// EndpointManager 网关端点管理器
type EndpointManager struct {
endpoints *concurrent.BalanceMap[string, []*Endpoint]
memory *haxmap.Map[string, *Endpoint]
selector func([]*Endpoint) *Endpoint
}
// GetEndpoint 获取端点
// - name: 端点名称
// - id: 使用端点的连接标识
func (slf *EndpointManager) GetEndpoint(name, id string) (*Endpoint, error) {
endpoint, exist := slf.memory.Get(id)
if exist {
return endpoint, nil
}
slf.endpoints.Atom(func(m map[string][]*Endpoint) {
endpoints, exist := m[name]
if !exist {
return
}
if len(endpoints) == 0 {
return
}
var available = make([]*Endpoint, 0, len(endpoints))
for _, e := range endpoints {
if !e.offline && e.state > 0 {
available = append(available, e)
}
}
if len(available) == 0 {
return
}
endpoint = slf.selector(available)
})
if endpoint == nil {
return nil, ErrEndpointNotExists
}
slf.memory.Set(id, endpoint)
return endpoint, nil
}
// AddEndpoint 添加端点
func (slf *EndpointManager) AddEndpoint(endpoint *Endpoint) error {
if endpoint.client.IsConnected() {
return ErrCannotAddRunningEndpoint
}
for _, e := range slf.endpoints.Get(endpoint.name) {
if e.address == endpoint.address {
return ErrEndpointAlreadyExists
}
}
go endpoint.Connect()
slf.endpoints.Atom(func(m map[string][]*Endpoint) {
m[endpoint.name] = append(m[endpoint.name], endpoint)
})
return nil
}
// RemoveEndpoint 移除端点
func (slf *EndpointManager) RemoveEndpoint(endpoint *Endpoint) error {
slf.endpoints.Atom(func(m map[string][]*Endpoint) {
var endpoints []*Endpoint
endpoints, exist := m[endpoint.name]
if !exist {
return
}
for i, e := range endpoints {
if e.address == endpoint.address {
endpoints = append(endpoints[:i], endpoints[i+1:]...)
break
}
}
m[endpoint.name] = endpoints
})
return nil
}

View File

@ -3,10 +3,12 @@ package gateway
import "errors"
var (
// ErrEndpointAlreadyExists 网关端点已存在
ErrEndpointAlreadyExists = errors.New("gateway: endpoint already exists")
// ErrCannotAddRunningEndpoint 无法添加一个正在运行的网关端点
ErrCannotAddRunningEndpoint = errors.New("gateway: cannot add a running endpoint")
// ErrEndpointNotExists 该名称下不存在任何端点
ErrEndpointNotExists = errors.New("gateway: endpoint not exists")
// ErrGatewayClosed 网关已关闭
ErrGatewayClosed = errors.New("gateway: gateway closed")
// ErrGatewayRunning 网关正在运行
ErrGatewayRunning = errors.New("gateway: gateway running")
// ErrConnectionNotFount 该端点下不存在该连接
ErrConnectionNotFount = errors.New("gateway: connection not found")
)

107
server/gateway/events.go Normal file
View File

@ -0,0 +1,107 @@
package gateway
import (
"github.com/kercylan98/minotaur/server"
"github.com/kercylan98/minotaur/utils/slice"
)
type (
ConnectionOpenedEventHandle func(gateway *Gateway, conn *server.Conn)
ConnectionClosedEventHandle func(gateway *Gateway, conn *server.Conn)
ConnectionReceivePacketEventHandle func(gateway *Gateway, conn *server.Conn, packet []byte)
EndpointConnectOpenedEventHandle func(gateway *Gateway, endpoint *Endpoint)
EndpointConnectClosedEventHandle func(gateway *Gateway, endpoint *Endpoint)
EndpointConnectReceivePacketEventHandle func(gateway *Gateway, endpoint *Endpoint, conn *server.Conn, packet []byte)
)
func newEvents() *events {
return &events{
connectionOpenedEventHandles: slice.NewPriority[ConnectionOpenedEventHandle](),
connectionClosedEventHandles: slice.NewPriority[ConnectionClosedEventHandle](),
connectionReceivePacketEventHandles: slice.NewPriority[ConnectionReceivePacketEventHandle](),
endpointConnectOpenedEventHandles: slice.NewPriority[EndpointConnectOpenedEventHandle](),
endpointConnectClosedEventHandles: slice.NewPriority[EndpointConnectClosedEventHandle](),
endpointConnectReceivePacketEventHandles: slice.NewPriority[EndpointConnectReceivePacketEventHandle](),
}
}
type events struct {
connectionOpenedEventHandles *slice.Priority[ConnectionOpenedEventHandle]
connectionClosedEventHandles *slice.Priority[ConnectionClosedEventHandle]
connectionReceivePacketEventHandles *slice.Priority[ConnectionReceivePacketEventHandle]
endpointConnectOpenedEventHandles *slice.Priority[EndpointConnectOpenedEventHandle]
endpointConnectClosedEventHandles *slice.Priority[EndpointConnectClosedEventHandle]
endpointConnectReceivePacketEventHandles *slice.Priority[EndpointConnectReceivePacketEventHandle]
}
// RegConnectionOpenedEventHandle 注册客户端连接打开事件处理函数
func (slf *events) RegConnectionOpenedEventHandle(handle ConnectionOpenedEventHandle, priority ...int) {
slf.connectionOpenedEventHandles.Append(handle, slice.GetValue(priority, 0))
}
func (slf *events) OnConnectionOpenedEvent(gateway *Gateway, conn *server.Conn) {
slf.connectionOpenedEventHandles.RangeValue(func(index int, value ConnectionOpenedEventHandle) bool {
value(gateway, conn)
return true
})
}
// RegConnectionClosedEventHandle 注册客户端连接关闭事件处理函数
func (slf *events) RegConnectionClosedEventHandle(handle ConnectionClosedEventHandle, priority ...int) {
slf.connectionClosedEventHandles.Append(handle, slice.GetValue(priority, 0))
}
func (slf *events) OnConnectionClosedEvent(gateway *Gateway, conn *server.Conn) {
slf.connectionClosedEventHandles.RangeValue(func(index int, value ConnectionClosedEventHandle) bool {
value(gateway, conn)
return true
})
}
// RegConnectionReceivePacketEventHandle 注册客户端连接接收数据包事件处理函数
func (slf *events) RegConnectionReceivePacketEventHandle(handle ConnectionReceivePacketEventHandle, priority ...int) {
slf.connectionReceivePacketEventHandles.Append(handle, slice.GetValue(priority, 0))
}
func (slf *events) OnConnectionReceivePacketEvent(gateway *Gateway, conn *server.Conn, packet []byte) {
slf.connectionReceivePacketEventHandles.RangeValue(func(index int, value ConnectionReceivePacketEventHandle) bool {
value(gateway, conn, packet)
return true
})
}
// RegEndpointConnectOpenedEventHandle 注册端点连接打开事件处理函数
func (slf *events) RegEndpointConnectOpenedEventHandle(handle EndpointConnectOpenedEventHandle, priority ...int) {
slf.endpointConnectOpenedEventHandles.Append(handle, slice.GetValue(priority, 0))
}
func (slf *events) OnEndpointConnectOpenedEvent(gateway *Gateway, endpoint *Endpoint) {
slf.endpointConnectOpenedEventHandles.RangeValue(func(index int, value EndpointConnectOpenedEventHandle) bool {
value(gateway, endpoint)
return true
})
}
// RegEndpointConnectClosedEventHandle 注册端点连接关闭事件处理函数
func (slf *events) RegEndpointConnectClosedEventHandle(handle EndpointConnectClosedEventHandle, priority ...int) {
slf.endpointConnectClosedEventHandles.Append(handle, slice.GetValue(priority, 0))
}
func (slf *events) OnEndpointConnectClosedEvent(gateway *Gateway, endpoint *Endpoint) {
slf.endpointConnectClosedEventHandles.RangeValue(func(index int, value EndpointConnectClosedEventHandle) bool {
value(gateway, endpoint)
return true
})
}
// RegEndpointConnectReceivePacketEventHandle 注册端点连接接收数据包事件处理函数
func (slf *events) RegEndpointConnectReceivePacketEventHandle(handle EndpointConnectReceivePacketEventHandle, priority ...int) {
slf.endpointConnectReceivePacketEventHandles.Append(handle, slice.GetValue(priority, 0))
}
func (slf *events) OnEndpointConnectReceivePacketEvent(gateway *Gateway, endpoint *Endpoint, conn *server.Conn, packet []byte) {
slf.endpointConnectReceivePacketEventHandles.RangeValue(func(index int, value EndpointConnectReceivePacketEventHandle) bool {
value(gateway, endpoint, conn, packet)
return true
})
}

View File

@ -2,14 +2,28 @@ package gateway
import (
"github.com/kercylan98/minotaur/server"
"github.com/kercylan98/minotaur/utils/random"
"math"
"sync"
"time"
)
type (
// EndpointSelector 端点选择器,用于从多个端点中选择一个可用的端点,如果没有可用的端点则返回 nil
EndpointSelector func(endpoints []*Endpoint) *Endpoint
)
// NewGateway 基于 server.Server 创建网关服务器
func NewGateway(srv *server.Server, options ...Option) *Gateway {
// - behaviorController 行为控制函数决定了客户端与网关服务器建立连接及接收数据包后的行为
func NewGateway(srv *server.Server, scanner Scanner, options ...Option) *Gateway {
gateway := &Gateway{
events: newEvents(),
srv: srv,
EndpointManager: NewEndpointManager(),
scanner: scanner,
es: make(map[string]map[string]*Endpoint),
ess: func(endpoints []*Endpoint) *Endpoint {
return endpoints[random.Int(0, len(endpoints)-1)]
},
}
for _, option := range options {
option(gateway)
@ -19,46 +33,109 @@ func NewGateway(srv *server.Server, options ...Option) *Gateway {
// Gateway 网关
type Gateway struct {
*EndpointManager // 端点管理器
*events
srv *server.Server // 网关服务器核心
scanner Scanner // 端点扫描器
es map[string]map[string]*Endpoint // 端点列表
esm sync.Mutex // 端点列表锁
ess EndpointSelector // 端点选择器
closed bool // 网关是否已关闭
running bool // 网关是否正在运行
}
// Run 运行网关
func (slf *Gateway) Run(addr string) error {
slf.srv.RegConnectionOpenedEvent(slf.onConnectionOpened, math.MinInt)
slf.srv.RegConnectionReceivePacketEvent(slf.onConnectionReceivePacket, math.MinInt)
return slf.srv.Run(addr)
if slf.closed {
return ErrGatewayClosed
}
if slf.running {
return ErrGatewayRunning
}
slf.srv.RegStartFinishEvent(func(srv *server.Server) {
go func() {
for !slf.closed {
endpoints, err := slf.scanner.GetEndpoints()
if err != nil {
continue
}
slf.esm.Lock()
for _, endpoint := range endpoints {
es, exist := slf.es[endpoint.GetName()]
if !exist {
es = make(map[string]*Endpoint)
slf.es[endpoint.GetName()] = es
}
e, exist := es[endpoint.GetAddress()]
if !exist {
e = endpoint
es[endpoint.GetAddress()] = e
go e.connect(slf)
}
}
slf.esm.Unlock()
time.Sleep(slf.scanner.GetInterval())
}
}()
}, math.MinInt)
slf.srv.RegStopEvent(func(srv *server.Server) {
slf.Shutdown()
}, math.MinInt)
slf.srv.RegConnectionOpenedEvent(func(srv *server.Server, conn *server.Conn) {
slf.OnConnectionOpenedEvent(slf, conn)
}, math.MinInt)
slf.srv.RegConnectionClosedEvent(func(srv *server.Server, conn *server.Conn, err any) {
slf.OnConnectionClosedEvent(slf, conn)
}, math.MinInt)
slf.srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet []byte) {
slf.OnConnectionReceivePacketEvent(slf, conn, packet)
}, math.MinInt)
slf.running = true
if err := slf.srv.Run(addr); err != nil {
return err
}
slf.running = false
return nil
}
// Shutdown 关闭网关
func (slf *Gateway) Shutdown() {
if !slf.closed {
return
}
slf.closed = true
slf.srv.Shutdown()
}
func (slf *Gateway) onConnectionOpened(srv *server.Server, conn *server.Conn) {
endpoint, err := slf.GetEndpoint("test", conn.GetID())
if err != nil {
conn.Close()
return
}
endpoint.Link(conn)
// Server 获取网关服务器核心
func (slf *Gateway) Server() *server.Server {
return slf.srv
}
// onConnectionReceivePacket 连接接收数据包事件
func (slf *Gateway) onConnectionReceivePacket(srv *server.Server, conn *server.Conn, packet []byte) {
endpoint, err := slf.GetEndpoint("test", conn.GetID())
if err != nil {
conn.Close()
return
// GetEndpoint 获取一个可用的端点
// - name: 端点名称
func (slf *Gateway) GetEndpoint(name string) (*Endpoint, error) {
slf.esm.Lock()
endpoints, exist := slf.es[name]
if !exist || len(endpoints) == 0 {
delete(slf.es, name)
slf.esm.Unlock()
return nil, ErrEndpointNotExists
}
packet, err = MarshalGatewayOutPacket(conn.GetID(), packet)
if err != nil {
conn.Close()
return
}
if conn.IsWebsocket() {
endpoint.WriteWS(conn.GetWST(), packet)
} else {
endpoint.Write(packet)
var available = make([]*Endpoint, 0, len(endpoints))
for _, e := range endpoints {
if e.state > 0 {
available = append(available, e)
}
}
slf.esm.Unlock()
if len(available) == 0 {
return nil, ErrEndpointNotExists
}
endpoint := slf.ess(available)
if endpoint == nil {
return nil, ErrEndpointNotExists
}
return endpoint, nil
}

View File

@ -8,6 +8,19 @@ import (
"time"
)
type Scanner struct {
}
func (slf *Scanner) GetEndpoints() ([]*gateway.Endpoint, error) {
return []*gateway.Endpoint{
gateway.NewEndpoint("test", client.NewWebsocket("ws://127.0.0.1:8889")),
}, nil
}
func (slf *Scanner) GetInterval() time.Duration {
return time.Second
}
func TestGateway_RunEndpointServer(t *testing.T) {
srv := server.New(server.NetworkWebsocket, server.WithDeadlockDetect(time.Second*3))
srv.RegConnectionPacketPreprocessEvent(func(srv *server.Server, conn *server.Conn, packet []byte, abort func(), usePacket func(newPacket []byte)) {
@ -39,13 +52,16 @@ func TestGateway_RunEndpointServer(t *testing.T) {
}
func TestGateway_Run(t *testing.T) {
srv := server.New(server.NetworkWebsocket, server.WithDeadlockDetect(time.Second*3))
gw := gateway.NewGateway(srv)
srv.RegStartFinishEvent(func(srv *server.Server) {
if err := gw.AddEndpoint(gateway.NewEndpoint(gw, "test", client.NewWebsocket("ws://127.0.0.1:8889"))); err != nil {
panic(err)
gw := gateway.NewGateway(server.New(server.NetworkWebsocket, server.WithDeadlockDetect(time.Second*3)), new(Scanner))
gw.RegConnectionReceivePacketEventHandle(func(gateway *gateway.Gateway, conn *server.Conn, packet []byte) {
endpoint, err := gateway.GetEndpoint("test")
if err == nil {
endpoint.Forward(conn, packet)
}
})
gw.RegEndpointConnectReceivePacketEventHandle(func(gateway *gateway.Gateway, endpoint *gateway.Endpoint, conn *server.Conn, packet []byte) {
conn.Write(packet)
})
if err := gw.Run(":8888"); err != nil {
panic(err)
}

View File

@ -5,8 +5,8 @@ type Option func(gateway *Gateway)
// WithEndpointSelector 设置端点选择器
// - 默认情况下,网关会随机选择一个端点作为目标,如果需要自定义端点选择器,可以通过该选项设置
func WithEndpointSelector(selector func([]*Endpoint) *Endpoint) Option {
func WithEndpointSelector(selector EndpointSelector) Option {
return func(gateway *Gateway) {
gateway.EndpointManager.selector = selector
gateway.ess = selector
}
}

11
server/gateway/scanner.go Normal file
View File

@ -0,0 +1,11 @@
package gateway
import "time"
// Scanner 端点扫描器
type Scanner interface {
// GetEndpoints 获取端点列表
GetEndpoints() ([]*Endpoint, error)
// GetInterval 获取扫描间隔
GetInterval() time.Duration
}