feat: gateway 支持连接与某一端点保持持久通讯,支持将端点的所有连接切换到另一端点
This commit is contained in:
parent
1cbe8ecf56
commit
6d5aa599d7
|
@ -126,17 +126,6 @@ func (slf *Endpoint) Forward(conn *server.Conn, packet []byte, callback ...func(
|
|||
return
|
||||
}
|
||||
|
||||
cb := func(err error) {
|
||||
if len(callback) > 0 {
|
||||
callback[0](err)
|
||||
}
|
||||
if err != nil {
|
||||
slf.connections.Del(conn.GetID())
|
||||
} else {
|
||||
slf.connections.Set(conn.GetID(), conn)
|
||||
}
|
||||
}
|
||||
|
||||
var superior *client.Client
|
||||
var superiorCount = -1
|
||||
for _, cli := range slf.client {
|
||||
|
@ -147,6 +136,20 @@ func (slf *Endpoint) Forward(conn *server.Conn, packet []byte, callback ...func(
|
|||
}
|
||||
}
|
||||
|
||||
var cb = func(err error) {
|
||||
if len(callback) > 0 {
|
||||
callback[0](err)
|
||||
}
|
||||
if err != nil {
|
||||
slf.connections.Del(conn.GetID())
|
||||
} else {
|
||||
slf.connections.Set(conn.GetID(), conn)
|
||||
slf.gateway.cceLock.Lock()
|
||||
slf.gateway.cce[conn.GetID()] = slf
|
||||
slf.gateway.cceLock.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
if conn.IsWebsocket() {
|
||||
superior.WriteWS(conn.GetWST(), packet, cb)
|
||||
} else {
|
||||
|
|
|
@ -24,6 +24,7 @@ func NewGateway(srv *server.Server, scanner Scanner, options ...Option) *Gateway
|
|||
ess: func(endpoints []*Endpoint) *Endpoint {
|
||||
return endpoints[random.Int(0, len(endpoints)-1)]
|
||||
},
|
||||
cce: make(map[string]*Endpoint),
|
||||
}
|
||||
for _, option := range options {
|
||||
option(gateway)
|
||||
|
@ -36,11 +37,13 @@ type Gateway struct {
|
|||
*events
|
||||
srv *server.Server // 网关服务器核心
|
||||
scanner Scanner // 端点扫描器
|
||||
es map[string]map[string]*Endpoint // 端点列表
|
||||
es map[string]map[string]*Endpoint // 端点列表 [name][address]
|
||||
esm sync.Mutex // 端点列表锁
|
||||
ess EndpointSelector // 端点选择器
|
||||
closed bool // 网关是否已关闭
|
||||
running bool // 网关是否正在运行
|
||||
cce map[string]*Endpoint // 连接当前连接的端点 [conn.ID]
|
||||
cceLock sync.RWMutex // 连接当前连接的端点锁
|
||||
}
|
||||
|
||||
// Run 运行网关
|
||||
|
@ -139,3 +142,30 @@ func (slf *Gateway) GetEndpoint(name string) (*Endpoint, error) {
|
|||
}
|
||||
return endpoint, nil
|
||||
}
|
||||
|
||||
// GetConnEndpoint 获取一个可用的端点,如果客户端已经连接到了某个端点,将优先返回该端点
|
||||
// - 当连接到的端点不可用或没有连接记录时,效果同 GetEndpoint 相同
|
||||
// - 当连接行为为有状态时,推荐使用该方法
|
||||
func (slf *Gateway) GetConnEndpoint(name string, conn *server.Conn) (*Endpoint, error) {
|
||||
slf.cceLock.RLock()
|
||||
endpoint, exist := slf.cce[conn.GetID()]
|
||||
slf.cceLock.RUnlock()
|
||||
if exist && endpoint.state > 0 {
|
||||
return endpoint, nil
|
||||
}
|
||||
return slf.GetEndpoint(name)
|
||||
}
|
||||
|
||||
// SwitchEndpoint 将端点端点的所有连接切换到另一个端点
|
||||
func (slf *Gateway) SwitchEndpoint(source, dest *Endpoint) {
|
||||
if source.name == dest.name && source.address == dest.address || source.state <= 0 || dest.state <= 0 {
|
||||
return
|
||||
}
|
||||
slf.cceLock.Lock()
|
||||
for id, endpoint := range slf.cce {
|
||||
if endpoint == source {
|
||||
slf.cce[id] = dest
|
||||
}
|
||||
}
|
||||
slf.cceLock.Unlock()
|
||||
}
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
package gateway_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/kercylan98/minotaur/server"
|
||||
"github.com/kercylan98/minotaur/server/client"
|
||||
"github.com/kercylan98/minotaur/server/gateway"
|
||||
"github.com/kercylan98/minotaur/utils/log"
|
||||
"github.com/kercylan98/minotaur/utils/super"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
@ -15,6 +16,7 @@ type Scanner struct {
|
|||
func (slf *Scanner) GetEndpoints() ([]*gateway.Endpoint, error) {
|
||||
return []*gateway.Endpoint{
|
||||
gateway.NewEndpoint("test", client.NewWebsocket("ws://127.0.0.1:8889"), gateway.WithEndpointConnectionPoolSize(10)),
|
||||
gateway.NewEndpoint("test", client.NewWebsocket("ws://127.0.0.1:8890"), gateway.WithEndpointConnectionPoolSize(10)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -22,7 +24,7 @@ func (slf *Scanner) GetInterval() time.Duration {
|
|||
return time.Second
|
||||
}
|
||||
|
||||
func TestGateway_RunEndpointServer(t *testing.T) {
|
||||
func TestGateway_RunEndpointServerA(t *testing.T) {
|
||||
srv := server.New(server.NetworkWebsocket, server.WithDeadlockDetect(time.Second*3))
|
||||
srv.RegConnectionPacketPreprocessEvent(func(srv *server.Server, conn *server.Conn, packet []byte, abort func(), usePacket func(newPacket []byte)) {
|
||||
addr, packet, err := gateway.UnmarshalGatewayOutPacket(packet)
|
||||
|
@ -45,20 +47,47 @@ func TestGateway_RunEndpointServer(t *testing.T) {
|
|||
return packet
|
||||
})
|
||||
srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet []byte) {
|
||||
conn.Write(packet)
|
||||
})
|
||||
srv.RegConnectionOpenedEvent(func(srv *server.Server, conn *server.Conn) {
|
||||
log.Info("connection opened", log.String("conn", conn.GetID()))
|
||||
conn.Write(super.StringToBytes(fmt.Sprintf("Endpoint A: %s", packet)))
|
||||
})
|
||||
if err := srv.Run(":8889"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGateway_RunEndpointServerB(t *testing.T) {
|
||||
srv := server.New(server.NetworkWebsocket, server.WithDeadlockDetect(time.Second*3))
|
||||
srv.RegConnectionPacketPreprocessEvent(func(srv *server.Server, conn *server.Conn, packet []byte, abort func(), usePacket func(newPacket []byte)) {
|
||||
addr, packet, err := gateway.UnmarshalGatewayOutPacket(packet)
|
||||
if err != nil {
|
||||
// 非网关的普通数据包
|
||||
return
|
||||
}
|
||||
usePacket(packet)
|
||||
conn.SetMessageData("gw-addr", addr)
|
||||
})
|
||||
srv.RegConnectionWritePacketBeforeEvent(func(srv *server.Server, conn *server.Conn, packet []byte) []byte {
|
||||
addr, ok := conn.GetMessageData("gw-addr").(string)
|
||||
if !ok {
|
||||
return packet
|
||||
}
|
||||
packet, err := gateway.MarshalGatewayInPacket(addr, time.Now().Unix(), packet)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return packet
|
||||
})
|
||||
srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet []byte) {
|
||||
conn.Write(super.StringToBytes(fmt.Sprintf("Endpoint B: %s", packet)))
|
||||
})
|
||||
if err := srv.Run(":8890"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGateway_Run(t *testing.T) {
|
||||
gw := gateway.NewGateway(server.New(server.NetworkWebsocket, server.WithDeadlockDetect(time.Second*3)), new(Scanner))
|
||||
gw.RegConnectionReceivePacketEventHandle(func(gateway *gateway.Gateway, conn *server.Conn, packet []byte) {
|
||||
endpoint, err := gateway.GetEndpoint("test")
|
||||
endpoint, err := gateway.GetConnEndpoint("test", conn)
|
||||
if err == nil {
|
||||
endpoint.Forward(conn, packet)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue