feat: gateway 支持连接与某一端点保持持久通讯,支持将端点的所有连接切换到另一端点

This commit is contained in:
kercylan98 2023-09-01 20:25:23 +08:00
parent 1cbe8ecf56
commit 6d5aa599d7
3 changed files with 81 additions and 19 deletions

View File

@ -126,17 +126,6 @@ func (slf *Endpoint) Forward(conn *server.Conn, packet []byte, callback ...func(
return
}
cb := func(err error) {
if len(callback) > 0 {
callback[0](err)
}
if err != nil {
slf.connections.Del(conn.GetID())
} else {
slf.connections.Set(conn.GetID(), conn)
}
}
var superior *client.Client
var superiorCount = -1
for _, cli := range slf.client {
@ -147,6 +136,20 @@ func (slf *Endpoint) Forward(conn *server.Conn, packet []byte, callback ...func(
}
}
var cb = func(err error) {
if len(callback) > 0 {
callback[0](err)
}
if err != nil {
slf.connections.Del(conn.GetID())
} else {
slf.connections.Set(conn.GetID(), conn)
slf.gateway.cceLock.Lock()
slf.gateway.cce[conn.GetID()] = slf
slf.gateway.cceLock.Unlock()
}
}
if conn.IsWebsocket() {
superior.WriteWS(conn.GetWST(), packet, cb)
} else {

View File

@ -24,6 +24,7 @@ func NewGateway(srv *server.Server, scanner Scanner, options ...Option) *Gateway
ess: func(endpoints []*Endpoint) *Endpoint {
return endpoints[random.Int(0, len(endpoints)-1)]
},
cce: make(map[string]*Endpoint),
}
for _, option := range options {
option(gateway)
@ -36,11 +37,13 @@ type Gateway struct {
*events
srv *server.Server // 网关服务器核心
scanner Scanner // 端点扫描器
es map[string]map[string]*Endpoint // 端点列表
es map[string]map[string]*Endpoint // 端点列表 [name][address]
esm sync.Mutex // 端点列表锁
ess EndpointSelector // 端点选择器
closed bool // 网关是否已关闭
running bool // 网关是否正在运行
cce map[string]*Endpoint // 连接当前连接的端点 [conn.ID]
cceLock sync.RWMutex // 连接当前连接的端点锁
}
// Run 运行网关
@ -139,3 +142,30 @@ func (slf *Gateway) GetEndpoint(name string) (*Endpoint, error) {
}
return endpoint, nil
}
// GetConnEndpoint 获取一个可用的端点,如果客户端已经连接到了某个端点,将优先返回该端点
// - 当连接到的端点不可用或没有连接记录时,效果同 GetEndpoint 相同
// - 当连接行为为有状态时,推荐使用该方法
func (slf *Gateway) GetConnEndpoint(name string, conn *server.Conn) (*Endpoint, error) {
slf.cceLock.RLock()
endpoint, exist := slf.cce[conn.GetID()]
slf.cceLock.RUnlock()
if exist && endpoint.state > 0 {
return endpoint, nil
}
return slf.GetEndpoint(name)
}
// SwitchEndpoint 将端点端点的所有连接切换到另一个端点
func (slf *Gateway) SwitchEndpoint(source, dest *Endpoint) {
if source.name == dest.name && source.address == dest.address || source.state <= 0 || dest.state <= 0 {
return
}
slf.cceLock.Lock()
for id, endpoint := range slf.cce {
if endpoint == source {
slf.cce[id] = dest
}
}
slf.cceLock.Unlock()
}

View File

@ -1,10 +1,11 @@
package gateway_test
import (
"fmt"
"github.com/kercylan98/minotaur/server"
"github.com/kercylan98/minotaur/server/client"
"github.com/kercylan98/minotaur/server/gateway"
"github.com/kercylan98/minotaur/utils/log"
"github.com/kercylan98/minotaur/utils/super"
"testing"
"time"
)
@ -15,6 +16,7 @@ type Scanner struct {
func (slf *Scanner) GetEndpoints() ([]*gateway.Endpoint, error) {
return []*gateway.Endpoint{
gateway.NewEndpoint("test", client.NewWebsocket("ws://127.0.0.1:8889"), gateway.WithEndpointConnectionPoolSize(10)),
gateway.NewEndpoint("test", client.NewWebsocket("ws://127.0.0.1:8890"), gateway.WithEndpointConnectionPoolSize(10)),
}, nil
}
@ -22,7 +24,7 @@ func (slf *Scanner) GetInterval() time.Duration {
return time.Second
}
func TestGateway_RunEndpointServer(t *testing.T) {
func TestGateway_RunEndpointServerA(t *testing.T) {
srv := server.New(server.NetworkWebsocket, server.WithDeadlockDetect(time.Second*3))
srv.RegConnectionPacketPreprocessEvent(func(srv *server.Server, conn *server.Conn, packet []byte, abort func(), usePacket func(newPacket []byte)) {
addr, packet, err := gateway.UnmarshalGatewayOutPacket(packet)
@ -45,20 +47,47 @@ func TestGateway_RunEndpointServer(t *testing.T) {
return packet
})
srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet []byte) {
conn.Write(packet)
})
srv.RegConnectionOpenedEvent(func(srv *server.Server, conn *server.Conn) {
log.Info("connection opened", log.String("conn", conn.GetID()))
conn.Write(super.StringToBytes(fmt.Sprintf("Endpoint A: %s", packet)))
})
if err := srv.Run(":8889"); err != nil {
panic(err)
}
}
func TestGateway_RunEndpointServerB(t *testing.T) {
srv := server.New(server.NetworkWebsocket, server.WithDeadlockDetect(time.Second*3))
srv.RegConnectionPacketPreprocessEvent(func(srv *server.Server, conn *server.Conn, packet []byte, abort func(), usePacket func(newPacket []byte)) {
addr, packet, err := gateway.UnmarshalGatewayOutPacket(packet)
if err != nil {
// 非网关的普通数据包
return
}
usePacket(packet)
conn.SetMessageData("gw-addr", addr)
})
srv.RegConnectionWritePacketBeforeEvent(func(srv *server.Server, conn *server.Conn, packet []byte) []byte {
addr, ok := conn.GetMessageData("gw-addr").(string)
if !ok {
return packet
}
packet, err := gateway.MarshalGatewayInPacket(addr, time.Now().Unix(), packet)
if err != nil {
panic(err)
}
return packet
})
srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet []byte) {
conn.Write(super.StringToBytes(fmt.Sprintf("Endpoint B: %s", packet)))
})
if err := srv.Run(":8890"); err != nil {
panic(err)
}
}
func TestGateway_Run(t *testing.T) {
gw := gateway.NewGateway(server.New(server.NetworkWebsocket, server.WithDeadlockDetect(time.Second*3)), new(Scanner))
gw.RegConnectionReceivePacketEventHandle(func(gateway *gateway.Gateway, conn *server.Conn, packet []byte) {
endpoint, err := gateway.GetEndpoint("test")
endpoint, err := gateway.GetConnEndpoint("test", conn)
if err == nil {
endpoint.Forward(conn, packet)
}