diff --git a/server/gateway/endpoint.go b/server/gateway/endpoint.go new file mode 100644 index 0000000..07b6fae --- /dev/null +++ b/server/gateway/endpoint.go @@ -0,0 +1,62 @@ +package gateway + +import ( + "github.com/kercylan98/minotaur/server" + "github.com/kercylan98/minotaur/server/client" + "time" +) + +// NewEndpoint 创建网关端点 +func NewEndpoint(name, address string) *Endpoint { + endpoint := &Endpoint{ + client: client.NewWebsocket(address), + name: name, + address: address, + } + endpoint.client.RegConnectionClosedEvent(endpoint.onConnectionClosed) + endpoint.client.RegConnectionReceivePacketEvent(endpoint.onConnectionReceivePacket) + return endpoint +} + +// Endpoint 网关端点 +type Endpoint struct { + client *client.Websocket // 端点客户端 + name string // 端点名称 + address string // 端点地址 + state float64 // 端点健康值(0为不可用,越高越优) + offline bool // 离线 +} + +// Offline 离线 +func (slf *Endpoint) Offline() { + slf.offline = true +} + +// Connect 连接端点 +func (slf *Endpoint) Connect() { + for { + var now = time.Now() + if err := slf.client.Run(); err == nil { + slf.state = 1 - (time.Since(now).Seconds() / 10) + break + } + time.Sleep(100 * time.Millisecond) + } +} + +// Write 写入数据 +func (slf *Endpoint) Write(packet server.Packet) { + slf.client.Write(packet) +} + +func (slf *Endpoint) onConnectionClosed(conn *client.Websocket, err any) { + if !slf.offline { + go slf.Connect() + } +} + +func (slf *Endpoint) onConnectionReceivePacket(conn *client.Websocket, packet server.Packet) { + p := UnpackGatewayPacket(packet) + packet.Data = p.Data + conn.GetData(p.ConnID).(*server.Conn).Write(packet) +} diff --git a/server/gateway/endpoint_manager.go b/server/gateway/endpoint_manager.go new file mode 100644 index 0000000..cb92600 --- /dev/null +++ b/server/gateway/endpoint_manager.go @@ -0,0 +1,89 @@ +package gateway + +import ( + "github.com/kercylan98/minotaur/server" + "github.com/kercylan98/minotaur/utils/concurrent" + "github.com/kercylan98/minotaur/utils/slice" +) + +// NewEndpointManager 创建网关端点管理器 +func NewEndpointManager() *EndpointManager { + em := &EndpointManager{ + endpoints: concurrent.NewBalanceMap[string, []*Endpoint](), + memory: concurrent.NewBalanceMap[string, *Endpoint](), + } + return em +} + +// EndpointManager 网关端点管理器 +type EndpointManager struct { + endpoints *concurrent.BalanceMap[string, []*Endpoint] + memory *concurrent.BalanceMap[string, *Endpoint] +} + +// GetEndpoint 获取端点 +func (slf *EndpointManager) GetEndpoint(name string, conn *server.Conn) (*Endpoint, error) { + endpoint, exist := slf.memory.GetExist(conn.GetID()) + if exist { + return endpoint, nil + } + slf.endpoints.Atom(func(m map[string][]*Endpoint) { + endpoints, exist := m[name] + if !exist { + return + } + if len(endpoints) == 0 { + return + } + // 随机获取 + endpoints = slice.Copy(endpoints) + slice.Shuffle(endpoints) + for _, e := range endpoints { + if e.offline || e.state <= 0 { + continue + } + endpoint = e + } + }) + if endpoint == nil { + return nil, ErrEndpointNotExists + } + slf.memory.Set(conn.GetID(), endpoint) + return endpoint, nil +} + +// AddEndpoint 添加端点 +func (slf *EndpointManager) AddEndpoint(endpoint *Endpoint) error { + if endpoint.client.IsConnected() { + return ErrCannotAddRunningEndpoint + } + for _, e := range slf.endpoints.Get(endpoint.name) { + if e.address == endpoint.address { + return ErrEndpointAlreadyExists + } + } + go endpoint.Connect() + slf.endpoints.Atom(func(m map[string][]*Endpoint) { + m[endpoint.name] = append(m[endpoint.name], endpoint) + }) + return nil +} + +// RemoveEndpoint 移除端点 +func (slf *EndpointManager) RemoveEndpoint(endpoint *Endpoint) error { + slf.endpoints.Atom(func(m map[string][]*Endpoint) { + var endpoints []*Endpoint + endpoints, exist := m[endpoint.name] + if !exist { + return + } + for i, e := range endpoints { + if e.address == endpoint.address { + endpoints = append(endpoints[:i], endpoints[i+1:]...) + break + } + } + m[endpoint.name] = endpoints + }) + return nil +} diff --git a/server/gateway/errors.go b/server/gateway/errors.go new file mode 100644 index 0000000..1242f01 --- /dev/null +++ b/server/gateway/errors.go @@ -0,0 +1,12 @@ +package gateway + +import "errors" + +var ( + // ErrEndpointAlreadyExists 网关端点已存在 + ErrEndpointAlreadyExists = errors.New("gateway: endpoint already exists") + // ErrCannotAddRunningEndpoint 无法添加一个正在运行的网关端点 + ErrCannotAddRunningEndpoint = errors.New("gateway: cannot add a running endpoint") + // ErrEndpointNotExists 该名称下不存在任何端点 + ErrEndpointNotExists = errors.New("gateway: endpoint not exists") +) diff --git a/server/gateway/gateway.go b/server/gateway/gateway.go new file mode 100644 index 0000000..7c749a4 --- /dev/null +++ b/server/gateway/gateway.go @@ -0,0 +1,71 @@ +package gateway + +import ( + "github.com/kercylan98/minotaur/server" + "github.com/kercylan98/minotaur/utils/super" +) + +// NewGateway 基于 server.Server 创建网关服务器 +func NewGateway(srv *server.Server) *Gateway { + gateway := &Gateway{ + srv: srv, + EndpointManager: NewEndpointManager(), + } + return gateway +} + +// Gateway 网关 +type Gateway struct { + *EndpointManager // 端点管理器 + srv *server.Server // 网关服务器核心 +} + +// Run 运行网关 +func (slf *Gateway) Run(addr string) error { + slf.srv.RegConnectionOpenedEvent(slf.onConnectionOpened) + slf.srv.RegConnectionReceivePacketEvent(slf.onConnectionReceivePacket) + return slf.srv.Run(addr) +} + +// Shutdown 关闭网关 +func (slf *Gateway) Shutdown() { + slf.srv.Shutdown() +} + +// onConnectionOpened 连接打开事件 +func (slf *Gateway) onConnectionOpened(srv *server.Server, conn *server.Conn) { + endpoint, err := slf.GetEndpoint("test", conn) + if err != nil { + conn.Close() + return + } + endpoint.client.SetData(conn.GetID(), conn) + conn.SetData("endpoint", endpoint) +} + +// onConnectionReceivePacket 连接接收数据包事件 +func (slf *Gateway) onConnectionReceivePacket(srv *server.Server, conn *server.Conn, packet server.Packet) { + conn.GetData("endpoint").(*Endpoint).Write(PackGatewayPacket(conn.GetID(), packet.WebsocketType, packet.Data)) +} + +// PackGatewayPacket 打包网关数据包 +func PackGatewayPacket(connID string, websocketType int, data []byte) server.Packet { + var gatewayPacket = Packet{ + ConnID: connID, + WebsocketType: websocketType, + Data: data, + } + return server.Packet{ + WebsocketType: websocketType, + Data: super.MarshalJSON(&gatewayPacket), + } +} + +// UnpackGatewayPacket 解包网关数据包 +func UnpackGatewayPacket(packet server.Packet) Packet { + var gatewayPacket Packet + if err := super.UnmarshalJSON(packet.Data, &gatewayPacket); err != nil { + panic(err) + } + return gatewayPacket +} diff --git a/server/gateway/gateway_test.go b/server/gateway/gateway_test.go new file mode 100644 index 0000000..3d78e1d --- /dev/null +++ b/server/gateway/gateway_test.go @@ -0,0 +1,33 @@ +package gateway_test + +import ( + "fmt" + "github.com/kercylan98/minotaur/server" + gateway2 "github.com/kercylan98/minotaur/server/gateway" + "testing" +) + +func TestGateway_RunEndpointServer(t *testing.T) { + srv := server.New(server.NetworkWebsocket) + srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet server.Packet) { + p := gateway2.UnpackGatewayPacket(packet) + fmt.Println("endpoint receive packet", string(p.Data)) + conn.Write(packet) + }) + if err := srv.Run(":8889"); err != nil { + panic(err) + } +} + +func TestGateway_Run(t *testing.T) { + srv := server.New(server.NetworkWebsocket) + gw := gateway2.NewGateway(srv) + srv.RegStartFinishEvent(func(srv *server.Server) { + if err := gw.AddEndpoint(gateway2.NewEndpoint("test", "ws://127.0.0.1:8889")); err != nil { + panic(err) + } + }) + if err := gw.Run(":8888"); err != nil { + panic(err) + } +} diff --git a/server/gateway/packet.go b/server/gateway/packet.go new file mode 100644 index 0000000..761fd8b --- /dev/null +++ b/server/gateway/packet.go @@ -0,0 +1,7 @@ +package gateway + +type Packet struct { + ConnID string + WebsocketType int + Data []byte +}