From 6d5aa599d76ac3e297077781401e039df6562ec7 Mon Sep 17 00:00:00 2001 From: kercylan98 Date: Fri, 1 Sep 2023 20:25:23 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20gateway=20=E6=94=AF=E6=8C=81=E8=BF=9E?= =?UTF-8?q?=E6=8E=A5=E4=B8=8E=E6=9F=90=E4=B8=80=E7=AB=AF=E7=82=B9=E4=BF=9D?= =?UTF-8?q?=E6=8C=81=E6=8C=81=E4=B9=85=E9=80=9A=E8=AE=AF=EF=BC=8C=E6=94=AF?= =?UTF-8?q?=E6=8C=81=E5=B0=86=E7=AB=AF=E7=82=B9=E7=9A=84=E6=89=80=E6=9C=89?= =?UTF-8?q?=E8=BF=9E=E6=8E=A5=E5=88=87=E6=8D=A2=E5=88=B0=E5=8F=A6=E4=B8=80?= =?UTF-8?q?=E7=AB=AF=E7=82=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/gateway/endpoint.go | 25 +++++++++++--------- server/gateway/gateway.go | 32 ++++++++++++++++++++++++- server/gateway/gateway_test.go | 43 ++++++++++++++++++++++++++++------ 3 files changed, 81 insertions(+), 19 deletions(-) diff --git a/server/gateway/endpoint.go b/server/gateway/endpoint.go index 30fde8a..ce74ea7 100644 --- a/server/gateway/endpoint.go +++ b/server/gateway/endpoint.go @@ -126,17 +126,6 @@ func (slf *Endpoint) Forward(conn *server.Conn, packet []byte, callback ...func( return } - cb := func(err error) { - if len(callback) > 0 { - callback[0](err) - } - if err != nil { - slf.connections.Del(conn.GetID()) - } else { - slf.connections.Set(conn.GetID(), conn) - } - } - var superior *client.Client var superiorCount = -1 for _, cli := range slf.client { @@ -147,6 +136,20 @@ func (slf *Endpoint) Forward(conn *server.Conn, packet []byte, callback ...func( } } + var cb = func(err error) { + if len(callback) > 0 { + callback[0](err) + } + if err != nil { + slf.connections.Del(conn.GetID()) + } else { + slf.connections.Set(conn.GetID(), conn) + slf.gateway.cceLock.Lock() + slf.gateway.cce[conn.GetID()] = slf + slf.gateway.cceLock.Unlock() + } + } + if conn.IsWebsocket() { superior.WriteWS(conn.GetWST(), packet, cb) } else { diff --git a/server/gateway/gateway.go b/server/gateway/gateway.go index ed9d9be..8d3569b 100644 --- a/server/gateway/gateway.go +++ b/server/gateway/gateway.go @@ -24,6 +24,7 @@ func NewGateway(srv *server.Server, scanner Scanner, options ...Option) *Gateway ess: func(endpoints []*Endpoint) *Endpoint { return endpoints[random.Int(0, len(endpoints)-1)] }, + cce: make(map[string]*Endpoint), } for _, option := range options { option(gateway) @@ -36,11 +37,13 @@ type Gateway struct { *events srv *server.Server // 网关服务器核心 scanner Scanner // 端点扫描器 - es map[string]map[string]*Endpoint // 端点列表 + es map[string]map[string]*Endpoint // 端点列表 [name][address] esm sync.Mutex // 端点列表锁 ess EndpointSelector // 端点选择器 closed bool // 网关是否已关闭 running bool // 网关是否正在运行 + cce map[string]*Endpoint // 连接当前连接的端点 [conn.ID] + cceLock sync.RWMutex // 连接当前连接的端点锁 } // Run 运行网关 @@ -139,3 +142,30 @@ func (slf *Gateway) GetEndpoint(name string) (*Endpoint, error) { } return endpoint, nil } + +// GetConnEndpoint 获取一个可用的端点,如果客户端已经连接到了某个端点,将优先返回该端点 +// - 当连接到的端点不可用或没有连接记录时,效果同 GetEndpoint 相同 +// - 当连接行为为有状态时,推荐使用该方法 +func (slf *Gateway) GetConnEndpoint(name string, conn *server.Conn) (*Endpoint, error) { + slf.cceLock.RLock() + endpoint, exist := slf.cce[conn.GetID()] + slf.cceLock.RUnlock() + if exist && endpoint.state > 0 { + return endpoint, nil + } + return slf.GetEndpoint(name) +} + +// SwitchEndpoint 将端点端点的所有连接切换到另一个端点 +func (slf *Gateway) SwitchEndpoint(source, dest *Endpoint) { + if source.name == dest.name && source.address == dest.address || source.state <= 0 || dest.state <= 0 { + return + } + slf.cceLock.Lock() + for id, endpoint := range slf.cce { + if endpoint == source { + slf.cce[id] = dest + } + } + slf.cceLock.Unlock() +} diff --git a/server/gateway/gateway_test.go b/server/gateway/gateway_test.go index 1bb2f56..435e0bc 100644 --- a/server/gateway/gateway_test.go +++ b/server/gateway/gateway_test.go @@ -1,10 +1,11 @@ package gateway_test import ( + "fmt" "github.com/kercylan98/minotaur/server" "github.com/kercylan98/minotaur/server/client" "github.com/kercylan98/minotaur/server/gateway" - "github.com/kercylan98/minotaur/utils/log" + "github.com/kercylan98/minotaur/utils/super" "testing" "time" ) @@ -15,6 +16,7 @@ type Scanner struct { func (slf *Scanner) GetEndpoints() ([]*gateway.Endpoint, error) { return []*gateway.Endpoint{ gateway.NewEndpoint("test", client.NewWebsocket("ws://127.0.0.1:8889"), gateway.WithEndpointConnectionPoolSize(10)), + gateway.NewEndpoint("test", client.NewWebsocket("ws://127.0.0.1:8890"), gateway.WithEndpointConnectionPoolSize(10)), }, nil } @@ -22,7 +24,7 @@ func (slf *Scanner) GetInterval() time.Duration { return time.Second } -func TestGateway_RunEndpointServer(t *testing.T) { +func TestGateway_RunEndpointServerA(t *testing.T) { srv := server.New(server.NetworkWebsocket, server.WithDeadlockDetect(time.Second*3)) srv.RegConnectionPacketPreprocessEvent(func(srv *server.Server, conn *server.Conn, packet []byte, abort func(), usePacket func(newPacket []byte)) { addr, packet, err := gateway.UnmarshalGatewayOutPacket(packet) @@ -45,20 +47,47 @@ func TestGateway_RunEndpointServer(t *testing.T) { return packet }) srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet []byte) { - conn.Write(packet) - }) - srv.RegConnectionOpenedEvent(func(srv *server.Server, conn *server.Conn) { - log.Info("connection opened", log.String("conn", conn.GetID())) + conn.Write(super.StringToBytes(fmt.Sprintf("Endpoint A: %s", packet))) }) if err := srv.Run(":8889"); err != nil { panic(err) } } +func TestGateway_RunEndpointServerB(t *testing.T) { + srv := server.New(server.NetworkWebsocket, server.WithDeadlockDetect(time.Second*3)) + srv.RegConnectionPacketPreprocessEvent(func(srv *server.Server, conn *server.Conn, packet []byte, abort func(), usePacket func(newPacket []byte)) { + addr, packet, err := gateway.UnmarshalGatewayOutPacket(packet) + if err != nil { + // 非网关的普通数据包 + return + } + usePacket(packet) + conn.SetMessageData("gw-addr", addr) + }) + srv.RegConnectionWritePacketBeforeEvent(func(srv *server.Server, conn *server.Conn, packet []byte) []byte { + addr, ok := conn.GetMessageData("gw-addr").(string) + if !ok { + return packet + } + packet, err := gateway.MarshalGatewayInPacket(addr, time.Now().Unix(), packet) + if err != nil { + panic(err) + } + return packet + }) + srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet []byte) { + conn.Write(super.StringToBytes(fmt.Sprintf("Endpoint B: %s", packet))) + }) + if err := srv.Run(":8890"); err != nil { + panic(err) + } +} + func TestGateway_Run(t *testing.T) { gw := gateway.NewGateway(server.New(server.NetworkWebsocket, server.WithDeadlockDetect(time.Second*3)), new(Scanner)) gw.RegConnectionReceivePacketEventHandle(func(gateway *gateway.Gateway, conn *server.Conn, packet []byte) { - endpoint, err := gateway.GetEndpoint("test") + endpoint, err := gateway.GetConnEndpoint("test", conn) if err == nil { endpoint.Forward(conn, packet) }