diff --git a/server/gateway/endpoint.go b/server/gateway/endpoint.go index aca3608..54c2a54 100644 --- a/server/gateway/endpoint.go +++ b/server/gateway/endpoint.go @@ -4,16 +4,16 @@ import ( "github.com/alphadose/haxmap" "github.com/kercylan98/minotaur/server" "github.com/kercylan98/minotaur/server/client" + "github.com/kercylan98/minotaur/utils/log" "time" ) // NewEndpoint 创建网关端点 -func NewEndpoint(gateway *Gateway, name string, client *client.Client, options ...EndpointOption) *Endpoint { +func NewEndpoint(name string, cli *client.Client, options ...EndpointOption) *Endpoint { endpoint := &Endpoint{ - gateway: gateway, - client: client, + client: cli, name: name, - address: client.GetServerAddr(), + address: cli.GetServerAddr(), connections: haxmap.New[string, *server.Conn](), } for _, option := range options { @@ -24,8 +24,6 @@ func NewEndpoint(gateway *Gateway, name string, client *client.Client, options . return 1 / (1 + 1.5*time.Duration(costUnixNano).Seconds()) } } - endpoint.client.RegConnectionClosedEvent(endpoint.onConnectionClosed) - endpoint.client.RegConnectionReceivePacketEvent(endpoint.onConnectionReceivePacket) return endpoint } @@ -36,34 +34,42 @@ type Endpoint struct { name string // 端点名称 address string // 端点地址 state float64 // 端点健康值(0为不可用,越高越优) - offline bool // 离线 evaluator func(costUnixNano float64) float64 // 端点健康值评估函数 - connections *haxmap.Map[string, *server.Conn] // 连接列表 + connections *haxmap.Map[string, *server.Conn] // 被该端点转发的连接列表 } -// Link 连接端点 -func (slf *Endpoint) Link(conn *server.Conn) { - slf.connections.Set(conn.GetID(), conn) -} - -// Unlink 断开连接 -func (slf *Endpoint) Unlink(conn *server.Conn) { - slf.connections.Del(conn.GetID()) -} - -// GetLink 获取连接 -func (slf *Endpoint) GetLink(id string) *server.Conn { - conn, _ := slf.connections.Get(id) - return conn -} - -// Offline 离线 -func (slf *Endpoint) Offline() { - slf.offline = true -} - -// Connect 连接端点 -func (slf *Endpoint) Connect() { +// connect 连接端点 +func (slf *Endpoint) connect(gateway *Gateway) { + slf.gateway = gateway + slf.client.RegConnectionOpenedEvent(func(conn *client.Client) { + slf.gateway.OnEndpointConnectOpenedEvent(slf.gateway, slf) + }) + slf.client.RegConnectionClosedEvent(func(conn *client.Client, err any) { + slf.gateway.OnEndpointConnectClosedEvent(slf.gateway, slf) + for { + cur := time.Now().UnixNano() + if err := slf.client.Run(); err == nil { + slf.state = slf.evaluator(float64(time.Now().UnixNano() - cur)) + break + } + time.Sleep(1000 * time.Millisecond) + } + }) + slf.client.RegConnectionReceivePacketEvent(func(conn *client.Client, wst int, packet []byte) { + addr, sendTime, packet, err := UnmarshalGatewayInPacket(packet) + if err != nil { + log.Error("Endpoint", log.String("Action", "ReceivePacket"), log.String("Name", slf.name), log.String("Addr", slf.address), log.Err(err)) + return + } + slf.state = slf.evaluator(float64(time.Now().UnixNano() - sendTime)) + c, ok := slf.connections.Get(addr) + if !ok { + log.Error("Endpoint", log.String("Action", "ReceivePacket"), log.String("Name", slf.name), log.String("Addr", slf.address), log.String("ConnAddr", addr), log.Err(ErrConnectionNotFount)) + return + } + c.SetWST(wst) + slf.gateway.OnEndpointConnectReceivePacketEvent(slf.gateway, slf, c, packet) + }) for { cur := time.Now().UnixNano() if err := slf.client.Run(); err == nil { @@ -74,33 +80,46 @@ func (slf *Endpoint) Connect() { } } -// Write 写入数据 -func (slf *Endpoint) Write(packet []byte, callback ...func(err error)) { - slf.client.Write(packet, callback...) +// GetName 获取端点名称 +func (slf *Endpoint) GetName() string { + return slf.name } -// WriteWS 写入 websocket 数据 -func (slf *Endpoint) WriteWS(wst int, packet []byte, callback ...func(err error)) { - slf.client.WriteWS(wst, packet, callback...) +// GetAddress 获取端点地址 +func (slf *Endpoint) GetAddress() string { + return slf.address } -// onConnectionClosed 与端点连接断开事件 -func (slf *Endpoint) onConnectionClosed(conn *client.Client, err any) { - if !slf.offline { - go slf.Connect() - } +// GetState 获取端点健康值 +func (slf *Endpoint) GetState() float64 { + return slf.state } -// onConnectionReceivePacket 接收到来自端点的数据包事件 -func (slf *Endpoint) onConnectionReceivePacket(conn *client.Client, wst int, packet []byte) { - addr, sendTime, packet, err := UnmarshalGatewayInPacket(packet) +// Forward 转发数据包到该端点 +// - 端点在处理数据包时,应区分数据包为普通直连数据包还是网关数据包。可通过 UnmarshalGatewayOutPacket 进行数据包解析,当解析失败且无其他数据包协议时,可认为该数据包为普通直连数据包。 +func (slf *Endpoint) Forward(conn *server.Conn, packet []byte, callback ...func(err error)) { + var err error + packet, err = MarshalGatewayOutPacket(conn.GetID(), packet) if err != nil { - panic(err) - } - slf.state = slf.evaluator(float64(time.Now().UnixNano() - sendTime)) - cli := slf.GetLink(addr) - if cli == nil { + if len(callback) > 0 { + callback[0](err) + } return } - cli.SetWST(wst).Write(packet) + + cb := func(err error) { + if len(callback) > 0 { + callback[0](err) + } + if err != nil { + slf.connections.Del(conn.GetID()) + } else { + slf.connections.Set(conn.GetID(), conn) + } + } + if conn.IsWebsocket() { + slf.client.WriteWS(conn.GetWST(), packet, cb) + } else { + slf.client.Write(packet, cb) + } } diff --git a/server/gateway/endpoint_manager.go b/server/gateway/endpoint_manager.go deleted file mode 100644 index 7695aa6..0000000 --- a/server/gateway/endpoint_manager.go +++ /dev/null @@ -1,96 +0,0 @@ -package gateway - -import ( - "github.com/alphadose/haxmap" - "github.com/kercylan98/minotaur/utils/concurrent" - "github.com/kercylan98/minotaur/utils/random" -) - -// NewEndpointManager 创建网关端点管理器 -func NewEndpointManager() *EndpointManager { - em := &EndpointManager{ - endpoints: concurrent.NewBalanceMap[string, []*Endpoint](), - memory: haxmap.New[string, *Endpoint](), - selector: func(endpoints []*Endpoint) *Endpoint { - return endpoints[random.Int(0, len(endpoints)-1)] - }, - } - return em -} - -// EndpointManager 网关端点管理器 -type EndpointManager struct { - endpoints *concurrent.BalanceMap[string, []*Endpoint] - memory *haxmap.Map[string, *Endpoint] - selector func([]*Endpoint) *Endpoint -} - -// GetEndpoint 获取端点 -// - name: 端点名称 -// - id: 使用端点的连接标识 -func (slf *EndpointManager) GetEndpoint(name, id string) (*Endpoint, error) { - endpoint, exist := slf.memory.Get(id) - if exist { - return endpoint, nil - } - slf.endpoints.Atom(func(m map[string][]*Endpoint) { - endpoints, exist := m[name] - if !exist { - return - } - if len(endpoints) == 0 { - return - } - var available = make([]*Endpoint, 0, len(endpoints)) - for _, e := range endpoints { - if !e.offline && e.state > 0 { - available = append(available, e) - } - } - if len(available) == 0 { - return - } - endpoint = slf.selector(available) - }) - if endpoint == nil { - return nil, ErrEndpointNotExists - } - slf.memory.Set(id, endpoint) - return endpoint, nil -} - -// AddEndpoint 添加端点 -func (slf *EndpointManager) AddEndpoint(endpoint *Endpoint) error { - if endpoint.client.IsConnected() { - return ErrCannotAddRunningEndpoint - } - for _, e := range slf.endpoints.Get(endpoint.name) { - if e.address == endpoint.address { - return ErrEndpointAlreadyExists - } - } - go endpoint.Connect() - slf.endpoints.Atom(func(m map[string][]*Endpoint) { - m[endpoint.name] = append(m[endpoint.name], endpoint) - }) - return nil -} - -// RemoveEndpoint 移除端点 -func (slf *EndpointManager) RemoveEndpoint(endpoint *Endpoint) error { - slf.endpoints.Atom(func(m map[string][]*Endpoint) { - var endpoints []*Endpoint - endpoints, exist := m[endpoint.name] - if !exist { - return - } - for i, e := range endpoints { - if e.address == endpoint.address { - endpoints = append(endpoints[:i], endpoints[i+1:]...) - break - } - } - m[endpoint.name] = endpoints - }) - return nil -} diff --git a/server/gateway/errors.go b/server/gateway/errors.go index 1242f01..ed518af 100644 --- a/server/gateway/errors.go +++ b/server/gateway/errors.go @@ -3,10 +3,12 @@ package gateway import "errors" var ( - // ErrEndpointAlreadyExists 网关端点已存在 - ErrEndpointAlreadyExists = errors.New("gateway: endpoint already exists") - // ErrCannotAddRunningEndpoint 无法添加一个正在运行的网关端点 - ErrCannotAddRunningEndpoint = errors.New("gateway: cannot add a running endpoint") // ErrEndpointNotExists 该名称下不存在任何端点 ErrEndpointNotExists = errors.New("gateway: endpoint not exists") + // ErrGatewayClosed 网关已关闭 + ErrGatewayClosed = errors.New("gateway: gateway closed") + // ErrGatewayRunning 网关正在运行 + ErrGatewayRunning = errors.New("gateway: gateway running") + // ErrConnectionNotFount 该端点下不存在该连接 + ErrConnectionNotFount = errors.New("gateway: connection not found") ) diff --git a/server/gateway/events.go b/server/gateway/events.go new file mode 100644 index 0000000..edfa1ce --- /dev/null +++ b/server/gateway/events.go @@ -0,0 +1,107 @@ +package gateway + +import ( + "github.com/kercylan98/minotaur/server" + "github.com/kercylan98/minotaur/utils/slice" +) + +type ( + ConnectionOpenedEventHandle func(gateway *Gateway, conn *server.Conn) + ConnectionClosedEventHandle func(gateway *Gateway, conn *server.Conn) + ConnectionReceivePacketEventHandle func(gateway *Gateway, conn *server.Conn, packet []byte) + EndpointConnectOpenedEventHandle func(gateway *Gateway, endpoint *Endpoint) + EndpointConnectClosedEventHandle func(gateway *Gateway, endpoint *Endpoint) + EndpointConnectReceivePacketEventHandle func(gateway *Gateway, endpoint *Endpoint, conn *server.Conn, packet []byte) +) + +func newEvents() *events { + return &events{ + connectionOpenedEventHandles: slice.NewPriority[ConnectionOpenedEventHandle](), + connectionClosedEventHandles: slice.NewPriority[ConnectionClosedEventHandle](), + connectionReceivePacketEventHandles: slice.NewPriority[ConnectionReceivePacketEventHandle](), + endpointConnectOpenedEventHandles: slice.NewPriority[EndpointConnectOpenedEventHandle](), + endpointConnectClosedEventHandles: slice.NewPriority[EndpointConnectClosedEventHandle](), + endpointConnectReceivePacketEventHandles: slice.NewPriority[EndpointConnectReceivePacketEventHandle](), + } +} + +type events struct { + connectionOpenedEventHandles *slice.Priority[ConnectionOpenedEventHandle] + connectionClosedEventHandles *slice.Priority[ConnectionClosedEventHandle] + connectionReceivePacketEventHandles *slice.Priority[ConnectionReceivePacketEventHandle] + endpointConnectOpenedEventHandles *slice.Priority[EndpointConnectOpenedEventHandle] + endpointConnectClosedEventHandles *slice.Priority[EndpointConnectClosedEventHandle] + endpointConnectReceivePacketEventHandles *slice.Priority[EndpointConnectReceivePacketEventHandle] +} + +// RegConnectionOpenedEventHandle 注册客户端连接打开事件处理函数 +func (slf *events) RegConnectionOpenedEventHandle(handle ConnectionOpenedEventHandle, priority ...int) { + slf.connectionOpenedEventHandles.Append(handle, slice.GetValue(priority, 0)) +} + +func (slf *events) OnConnectionOpenedEvent(gateway *Gateway, conn *server.Conn) { + slf.connectionOpenedEventHandles.RangeValue(func(index int, value ConnectionOpenedEventHandle) bool { + value(gateway, conn) + return true + }) +} + +// RegConnectionClosedEventHandle 注册客户端连接关闭事件处理函数 +func (slf *events) RegConnectionClosedEventHandle(handle ConnectionClosedEventHandle, priority ...int) { + slf.connectionClosedEventHandles.Append(handle, slice.GetValue(priority, 0)) +} + +func (slf *events) OnConnectionClosedEvent(gateway *Gateway, conn *server.Conn) { + slf.connectionClosedEventHandles.RangeValue(func(index int, value ConnectionClosedEventHandle) bool { + value(gateway, conn) + return true + }) +} + +// RegConnectionReceivePacketEventHandle 注册客户端连接接收数据包事件处理函数 +func (slf *events) RegConnectionReceivePacketEventHandle(handle ConnectionReceivePacketEventHandle, priority ...int) { + slf.connectionReceivePacketEventHandles.Append(handle, slice.GetValue(priority, 0)) +} + +func (slf *events) OnConnectionReceivePacketEvent(gateway *Gateway, conn *server.Conn, packet []byte) { + slf.connectionReceivePacketEventHandles.RangeValue(func(index int, value ConnectionReceivePacketEventHandle) bool { + value(gateway, conn, packet) + return true + }) +} + +// RegEndpointConnectOpenedEventHandle 注册端点连接打开事件处理函数 +func (slf *events) RegEndpointConnectOpenedEventHandle(handle EndpointConnectOpenedEventHandle, priority ...int) { + slf.endpointConnectOpenedEventHandles.Append(handle, slice.GetValue(priority, 0)) +} + +func (slf *events) OnEndpointConnectOpenedEvent(gateway *Gateway, endpoint *Endpoint) { + slf.endpointConnectOpenedEventHandles.RangeValue(func(index int, value EndpointConnectOpenedEventHandle) bool { + value(gateway, endpoint) + return true + }) +} + +// RegEndpointConnectClosedEventHandle 注册端点连接关闭事件处理函数 +func (slf *events) RegEndpointConnectClosedEventHandle(handle EndpointConnectClosedEventHandle, priority ...int) { + slf.endpointConnectClosedEventHandles.Append(handle, slice.GetValue(priority, 0)) +} + +func (slf *events) OnEndpointConnectClosedEvent(gateway *Gateway, endpoint *Endpoint) { + slf.endpointConnectClosedEventHandles.RangeValue(func(index int, value EndpointConnectClosedEventHandle) bool { + value(gateway, endpoint) + return true + }) +} + +// RegEndpointConnectReceivePacketEventHandle 注册端点连接接收数据包事件处理函数 +func (slf *events) RegEndpointConnectReceivePacketEventHandle(handle EndpointConnectReceivePacketEventHandle, priority ...int) { + slf.endpointConnectReceivePacketEventHandles.Append(handle, slice.GetValue(priority, 0)) +} + +func (slf *events) OnEndpointConnectReceivePacketEvent(gateway *Gateway, endpoint *Endpoint, conn *server.Conn, packet []byte) { + slf.endpointConnectReceivePacketEventHandles.RangeValue(func(index int, value EndpointConnectReceivePacketEventHandle) bool { + value(gateway, endpoint, conn, packet) + return true + }) +} diff --git a/server/gateway/gateway.go b/server/gateway/gateway.go index 13a4018..ed9d9be 100644 --- a/server/gateway/gateway.go +++ b/server/gateway/gateway.go @@ -2,14 +2,28 @@ package gateway import ( "github.com/kercylan98/minotaur/server" + "github.com/kercylan98/minotaur/utils/random" "math" + "sync" + "time" +) + +type ( + // EndpointSelector 端点选择器,用于从多个端点中选择一个可用的端点,如果没有可用的端点则返回 nil + EndpointSelector func(endpoints []*Endpoint) *Endpoint ) // NewGateway 基于 server.Server 创建网关服务器 -func NewGateway(srv *server.Server, options ...Option) *Gateway { +// - behaviorController 行为控制函数决定了客户端与网关服务器建立连接及接收数据包后的行为 +func NewGateway(srv *server.Server, scanner Scanner, options ...Option) *Gateway { gateway := &Gateway{ - srv: srv, - EndpointManager: NewEndpointManager(), + events: newEvents(), + srv: srv, + scanner: scanner, + es: make(map[string]map[string]*Endpoint), + ess: func(endpoints []*Endpoint) *Endpoint { + return endpoints[random.Int(0, len(endpoints)-1)] + }, } for _, option := range options { option(gateway) @@ -19,46 +33,109 @@ func NewGateway(srv *server.Server, options ...Option) *Gateway { // Gateway 网关 type Gateway struct { - *EndpointManager // 端点管理器 - srv *server.Server // 网关服务器核心 + *events + srv *server.Server // 网关服务器核心 + scanner Scanner // 端点扫描器 + es map[string]map[string]*Endpoint // 端点列表 + esm sync.Mutex // 端点列表锁 + ess EndpointSelector // 端点选择器 + closed bool // 网关是否已关闭 + running bool // 网关是否正在运行 } // Run 运行网关 func (slf *Gateway) Run(addr string) error { - slf.srv.RegConnectionOpenedEvent(slf.onConnectionOpened, math.MinInt) - slf.srv.RegConnectionReceivePacketEvent(slf.onConnectionReceivePacket, math.MinInt) - return slf.srv.Run(addr) + if slf.closed { + return ErrGatewayClosed + } + if slf.running { + return ErrGatewayRunning + } + slf.srv.RegStartFinishEvent(func(srv *server.Server) { + go func() { + for !slf.closed { + endpoints, err := slf.scanner.GetEndpoints() + if err != nil { + continue + } + slf.esm.Lock() + for _, endpoint := range endpoints { + es, exist := slf.es[endpoint.GetName()] + if !exist { + es = make(map[string]*Endpoint) + slf.es[endpoint.GetName()] = es + } + e, exist := es[endpoint.GetAddress()] + if !exist { + e = endpoint + es[endpoint.GetAddress()] = e + go e.connect(slf) + } + } + slf.esm.Unlock() + time.Sleep(slf.scanner.GetInterval()) + } + }() + }, math.MinInt) + slf.srv.RegStopEvent(func(srv *server.Server) { + slf.Shutdown() + }, math.MinInt) + slf.srv.RegConnectionOpenedEvent(func(srv *server.Server, conn *server.Conn) { + slf.OnConnectionOpenedEvent(slf, conn) + }, math.MinInt) + slf.srv.RegConnectionClosedEvent(func(srv *server.Server, conn *server.Conn, err any) { + slf.OnConnectionClosedEvent(slf, conn) + }, math.MinInt) + slf.srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet []byte) { + slf.OnConnectionReceivePacketEvent(slf, conn, packet) + }, math.MinInt) + slf.running = true + if err := slf.srv.Run(addr); err != nil { + return err + } + slf.running = false + return nil } // Shutdown 关闭网关 func (slf *Gateway) Shutdown() { + if !slf.closed { + return + } + slf.closed = true slf.srv.Shutdown() } -func (slf *Gateway) onConnectionOpened(srv *server.Server, conn *server.Conn) { - endpoint, err := slf.GetEndpoint("test", conn.GetID()) - if err != nil { - conn.Close() - return - } - endpoint.Link(conn) +// Server 获取网关服务器核心 +func (slf *Gateway) Server() *server.Server { + return slf.srv } -// onConnectionReceivePacket 连接接收数据包事件 -func (slf *Gateway) onConnectionReceivePacket(srv *server.Server, conn *server.Conn, packet []byte) { - endpoint, err := slf.GetEndpoint("test", conn.GetID()) - if err != nil { - conn.Close() - return +// GetEndpoint 获取一个可用的端点 +// - name: 端点名称 +func (slf *Gateway) GetEndpoint(name string) (*Endpoint, error) { + slf.esm.Lock() + endpoints, exist := slf.es[name] + if !exist || len(endpoints) == 0 { + delete(slf.es, name) + slf.esm.Unlock() + return nil, ErrEndpointNotExists } - packet, err = MarshalGatewayOutPacket(conn.GetID(), packet) - if err != nil { - conn.Close() - return + + var available = make([]*Endpoint, 0, len(endpoints)) + for _, e := range endpoints { + if e.state > 0 { + available = append(available, e) + } } - if conn.IsWebsocket() { - endpoint.WriteWS(conn.GetWST(), packet) - } else { - endpoint.Write(packet) + slf.esm.Unlock() + if len(available) == 0 { + return nil, ErrEndpointNotExists } + + endpoint := slf.ess(available) + if endpoint == nil { + return nil, ErrEndpointNotExists + } + return endpoint, nil } diff --git a/server/gateway/gateway_test.go b/server/gateway/gateway_test.go index 43ce56c..ba63006 100644 --- a/server/gateway/gateway_test.go +++ b/server/gateway/gateway_test.go @@ -8,6 +8,19 @@ import ( "time" ) +type Scanner struct { +} + +func (slf *Scanner) GetEndpoints() ([]*gateway.Endpoint, error) { + return []*gateway.Endpoint{ + gateway.NewEndpoint("test", client.NewWebsocket("ws://127.0.0.1:8889")), + }, nil +} + +func (slf *Scanner) GetInterval() time.Duration { + return time.Second +} + func TestGateway_RunEndpointServer(t *testing.T) { srv := server.New(server.NetworkWebsocket, server.WithDeadlockDetect(time.Second*3)) srv.RegConnectionPacketPreprocessEvent(func(srv *server.Server, conn *server.Conn, packet []byte, abort func(), usePacket func(newPacket []byte)) { @@ -39,13 +52,16 @@ func TestGateway_RunEndpointServer(t *testing.T) { } func TestGateway_Run(t *testing.T) { - srv := server.New(server.NetworkWebsocket, server.WithDeadlockDetect(time.Second*3)) - gw := gateway.NewGateway(srv) - srv.RegStartFinishEvent(func(srv *server.Server) { - if err := gw.AddEndpoint(gateway.NewEndpoint(gw, "test", client.NewWebsocket("ws://127.0.0.1:8889"))); err != nil { - panic(err) + gw := gateway.NewGateway(server.New(server.NetworkWebsocket, server.WithDeadlockDetect(time.Second*3)), new(Scanner)) + gw.RegConnectionReceivePacketEventHandle(func(gateway *gateway.Gateway, conn *server.Conn, packet []byte) { + endpoint, err := gateway.GetEndpoint("test") + if err == nil { + endpoint.Forward(conn, packet) } }) + gw.RegEndpointConnectReceivePacketEventHandle(func(gateway *gateway.Gateway, endpoint *gateway.Endpoint, conn *server.Conn, packet []byte) { + conn.Write(packet) + }) if err := gw.Run(":8888"); err != nil { panic(err) } diff --git a/server/gateway/options.go b/server/gateway/options.go index a467c30..5a011c9 100644 --- a/server/gateway/options.go +++ b/server/gateway/options.go @@ -5,8 +5,8 @@ type Option func(gateway *Gateway) // WithEndpointSelector 设置端点选择器 // - 默认情况下,网关会随机选择一个端点作为目标,如果需要自定义端点选择器,可以通过该选项设置 -func WithEndpointSelector(selector func([]*Endpoint) *Endpoint) Option { +func WithEndpointSelector(selector EndpointSelector) Option { return func(gateway *Gateway) { - gateway.EndpointManager.selector = selector + gateway.ess = selector } } diff --git a/server/gateway/scanner.go b/server/gateway/scanner.go new file mode 100644 index 0000000..b10350a --- /dev/null +++ b/server/gateway/scanner.go @@ -0,0 +1,11 @@ +package gateway + +import "time" + +// Scanner 端点扫描器 +type Scanner interface { + // GetEndpoints 获取端点列表 + GetEndpoints() ([]*Endpoint, error) + // GetInterval 获取扫描间隔 + GetInterval() time.Duration +}