diff --git a/game/builtin/player.go b/game/builtin/player.go index 265f06a..876178a 100644 --- a/game/builtin/player.go +++ b/game/builtin/player.go @@ -34,7 +34,7 @@ func (slf *Player[ID]) UseConn(conn *server.Conn) { } // Send 向该玩家发送数据 -func (slf *Player[ID]) Send(packet server.Packet) { +func (slf *Player[ID]) Send(packet []byte) { slf.conn.Write(packet) } diff --git a/server/client/client.go b/server/client/client.go new file mode 100644 index 0000000..b380276 --- /dev/null +++ b/server/client/client.go @@ -0,0 +1,158 @@ +package client + +import ( + "github.com/kercylan98/minotaur/utils/concurrent" + "sync" + "time" +) + +// NewClient 创建客户端 +func NewClient(core Core) *Client { + client := &Client{ + events: new(events), + core: core, + } + return client +} + +// Client 客户端 +type Client struct { + *events + core Core + mutex sync.Mutex + packetPool *concurrent.Pool[*Packet] + packets []*Packet + + accumulate []*Packet +} + +func (slf *Client) Run() error { + var wait = new(sync.WaitGroup) + wait.Add(1) + go slf.writeLoop(wait) + wait.Wait() + var runState = make(chan error) + go func() { + defer func() { + if err := recover(); err != nil { + slf.Close(err.(error)) + } + }() + slf.core.Run(runState, slf.onReceive) + }() + err := <-runState + if err != nil { + slf.Close() + return err + } + slf.OnConnectionOpenedEvent(slf) + return nil +} + +// IsConnected 是否已连接 +func (slf *Client) IsConnected() bool { + return slf.packetPool != nil +} + +// Close 关闭 +func (slf *Client) Close(err ...error) { + slf.core.Close() + if slf.packetPool != nil { + slf.packetPool.Close() + slf.packetPool = nil + } + slf.packets = nil + if len(err) > 0 { + slf.OnConnectionClosedEvent(slf, err[0]) + } else { + slf.OnConnectionClosedEvent(slf, nil) + } +} + +// WriteWS 向连接中写入指定 websocket 数据类型 +// - wst: websocket模式中指定消息类型 +func (slf *Client) WriteWS(wst int, packet []byte, callback ...func(err error)) { + slf.write(wst, packet, callback...) +} + +// Write 向连接中写入数据 +func (slf *Client) Write(packet []byte, callback ...func(err error)) { + slf.write(0, packet, callback...) +} + +// write 向连接中写入数据 +// - messageType: websocket模式中指定消息类型 +func (slf *Client) write(wst int, packet []byte, callback ...func(err error)) { + cp := slf.packetPool.Get() + cp.wst = wst + cp.data = packet + if len(callback) > 0 { + cp.callback = callback[0] + } + if slf.packetPool == nil { + slf.accumulate = append(slf.accumulate, cp) + return + } + slf.mutex.Lock() + slf.packets = append(slf.packets, cp) + slf.mutex.Unlock() +} + +// writeLoop 写循环 +func (slf *Client) writeLoop(wait *sync.WaitGroup) { + slf.packetPool = concurrent.NewPool[*Packet](10*1024, + func() *Packet { + return &Packet{} + }, func(data *Packet) { + data.wst = 0 + data.data = nil + data.callback = nil + }, + ) + slf.mutex.Lock() + slf.packets = append(slf.packets, slf.accumulate...) + slf.accumulate = nil + slf.mutex.Unlock() + defer func() { + if err := recover(); err != nil { + slf.Close(err.(error)) + } + }() + wait.Done() + for { + slf.mutex.Lock() + if slf.packetPool == nil { + slf.mutex.Unlock() + return + } + if len(slf.packets) == 0 { + slf.mutex.Unlock() + time.Sleep(50 * time.Millisecond) + continue + } + packets := slf.packets[0:] + slf.packets = slf.packets[0:0] + slf.mutex.Unlock() + for i := 0; i < len(packets); i++ { + data := packets[i] + var err = slf.core.Write(data) + callback := data.callback + slf.packetPool.Release(data) + if callback != nil { + callback(err) + } + if err != nil { + panic(err) + } + } + } +} + +func (slf *Client) onReceive(wst int, packet []byte) { + slf.OnConnectionReceivePacketEvent(slf, wst, packet) +} + +// GetServerAddr 获取服务器地址 +func (slf *Client) GetServerAddr() string { + return slf.core.GetServerAddr() +} diff --git a/server/client/client_core.go b/server/client/client_core.go new file mode 100644 index 0000000..2a0444b --- /dev/null +++ b/server/client/client_core.go @@ -0,0 +1,17 @@ +package client + +type Core interface { + // Run 启动客户端 + // - runState: 运行状态,当客户端启动完成时,应该向该通道发送 error 或 nil + // - receive: 接收到数据包时应该将数据包发送到该函数,wst 表示 websocket 的数据类型,data 表示数据包 + Run(runState chan<- error, receive func(wst int, packet []byte)) + + // Write 向客户端写入数据包 + Write(packet *Packet) error + + // Close 关闭客户端 + Close() + + // GetServerAddr 获取服务器地址 + GetServerAddr() string +} diff --git a/server/client/client_events.go b/server/client/client_events.go new file mode 100644 index 0000000..be14590 --- /dev/null +++ b/server/client/client_events.go @@ -0,0 +1,46 @@ +package client + +type ( + ConnectionClosedEventHandle func(conn *Client, err any) + ConnectionOpenedEventHandle func(conn *Client) + ConnectionReceivePacketEventHandle func(conn *Client, wst int, packet []byte) +) + +type events struct { + ConnectionClosedEventHandles []ConnectionClosedEventHandle + ConnectionOpenedEventHandles []ConnectionOpenedEventHandle + ConnectionReceivePacketEventHandles []ConnectionReceivePacketEventHandle +} + +// RegConnectionClosedEvent 注册连接关闭事件 +func (slf *events) RegConnectionClosedEvent(handle ConnectionClosedEventHandle) { + slf.ConnectionClosedEventHandles = append(slf.ConnectionClosedEventHandles, handle) +} + +func (slf *events) OnConnectionClosedEvent(conn *Client, err any) { + for _, handle := range slf.ConnectionClosedEventHandles { + handle(conn, err) + } +} + +// RegConnectionOpenedEvent 注册连接打开事件 +func (slf *events) RegConnectionOpenedEvent(handle ConnectionOpenedEventHandle) { + slf.ConnectionOpenedEventHandles = append(slf.ConnectionOpenedEventHandles, handle) +} + +func (slf *events) OnConnectionOpenedEvent(conn *Client) { + for _, handle := range slf.ConnectionOpenedEventHandles { + handle(conn) + } +} + +// RegConnectionReceivePacketEvent 注册连接接收数据包事件 +func (slf *events) RegConnectionReceivePacketEvent(handle ConnectionReceivePacketEventHandle) { + slf.ConnectionReceivePacketEventHandles = append(slf.ConnectionReceivePacketEventHandles, handle) +} + +func (slf *events) OnConnectionReceivePacketEvent(conn *Client, wst int, packet []byte) { + for _, handle := range slf.ConnectionReceivePacketEventHandles { + handle(conn, wst, packet) + } +} diff --git a/server/client/packet.go b/server/client/packet.go index 8a6d016..040a58f 100644 --- a/server/client/packet.go +++ b/server/client/packet.go @@ -1,7 +1,7 @@ package client type Packet struct { - websocketMessageType int // websocket 消息类型 - packet []byte // 数据包 - callback func(err error) // 回调函数 + wst int // websocket 的数据类型 + data []byte // 数据包 + callback func(err error) // 回调函数 } diff --git a/server/client/uds.go b/server/client/uds.go index 59ade7a..d2a0bcf 100644 --- a/server/client/uds.go +++ b/server/client/uds.go @@ -1,156 +1,48 @@ package client import ( - "github.com/kercylan98/minotaur/server" - "github.com/kercylan98/minotaur/utils/concurrent" "net" - "sync" - "time" ) -// NewUnixDomainSocket 创建 unix domain socket 客户端 -func NewUnixDomainSocket(addr string) *UnixDomainSocket { - return &UnixDomainSocket{ - udsEvents: new(udsEvents), - addr: addr, - data: map[string]any{}, - } +func NewUnixDomainSocket(addr string) *Client { + return NewClient(&UnixDomainSocket{ + addr: addr, + }) } -// UnixDomainSocket unix domain socket 客户端 type UnixDomainSocket struct { - *udsEvents - conn net.Conn - addr string - data map[string]any - - mutex sync.Mutex - packetPool *concurrent.Pool[*Packet] - packets []*Packet - - accumulate []server.Packet + conn net.Conn + addr string + closed bool } -// Run 启动 -func (slf *UnixDomainSocket) Run() error { +func (slf *UnixDomainSocket) Run(runState chan<- error, receive func(wst int, packet []byte)) { c, err := net.Dial("unix", slf.addr) if err != nil { - return err - } - slf.conn = c - var wait = new(sync.WaitGroup) - wait.Add(1) - go slf.writeLoop(wait) - wait.Wait() - go func() { - defer func() { - if err := recover(); err != nil { - slf.Close() - slf.OnUDSConnectionClosedEvent(slf, err) - } - }() - slf.OnUDSConnectionOpenedEvent(slf) - packet := make([]byte, 1024) - for slf.packetPool != nil { - n, readErr := slf.conn.Read(packet) - if readErr != nil { - panic(readErr) - } - slf.OnUDSConnectionReceivePacketEvent(slf, server.NewPacket(packet[:n])) - } - }() - return nil -} - -// Close 关闭 -func (slf *UnixDomainSocket) Close() { - if slf.packetPool != nil { - slf.packetPool.Close() - slf.packetPool = nil - } - slf.packets = nil -} - -// IsConnected 是否已连接 -func (slf *UnixDomainSocket) IsConnected() bool { - return slf.packetPool != nil -} - -// GetData 获取数据 -func (slf *UnixDomainSocket) GetData(key string) any { - return slf.data[key] -} - -// SetData 设置数据 -func (slf *UnixDomainSocket) SetData(key string, value any) { - slf.data[key] = value -} - -// Write 向连接中写入数据 -func (slf *UnixDomainSocket) Write(packet server.Packet) { - if slf.packetPool == nil { - slf.accumulate = append(slf.accumulate, packet) + runState <- err return } - cp := slf.packetPool.Get() - cp.websocketMessageType = packet.WebsocketType - cp.packet = packet.Data - slf.mutex.Lock() - slf.packets = append(slf.packets, cp) - slf.mutex.Unlock() + slf.conn = c + runState <- nil + packet := make([]byte, 1024) + for !slf.closed { + n, readErr := slf.conn.Read(packet) + if readErr != nil { + panic(readErr) + } + receive(0, packet[:n]) + } } -// writeLoop 写循环 -func (slf *UnixDomainSocket) writeLoop(wait *sync.WaitGroup) { - slf.packetPool = concurrent.NewPool[*Packet](10*1024, - func() *Packet { - return &Packet{} - }, func(data *Packet) { - data.packet = nil - data.websocketMessageType = 0 - data.callback = nil - }, - ) - slf.mutex.Lock() - for _, packet := range slf.accumulate { - cp := slf.packetPool.Get() - cp.websocketMessageType = packet.WebsocketType - cp.packet = packet.Data - slf.packets = append(slf.packets, cp) - } - slf.accumulate = nil - slf.mutex.Unlock() - defer func() { - if err := recover(); err != nil { - slf.Close() - } - }() - wait.Done() - for { - slf.mutex.Lock() - if slf.packetPool == nil { - slf.mutex.Unlock() - return - } - if len(slf.packets) == 0 { - slf.mutex.Unlock() - time.Sleep(50 * time.Millisecond) - continue - } - packets := slf.packets[0:] - slf.packets = slf.packets[0:0] - slf.mutex.Unlock() - for i := 0; i < len(packets); i++ { - data := packets[i] - var _, err = slf.conn.Write(data.packet) - callback := data.callback - slf.packetPool.Release(data) - if callback != nil { - callback(err) - } - if err != nil { - panic(err) - } - } - } +func (slf *UnixDomainSocket) Write(packet *Packet) error { + _, err := slf.conn.Write(packet.data) + return err +} + +func (slf *UnixDomainSocket) Close() { + slf.closed = true +} + +func (slf *UnixDomainSocket) GetServerAddr() string { + return slf.addr } diff --git a/server/client/uds_events.go b/server/client/uds_events.go deleted file mode 100644 index cf3433d..0000000 --- a/server/client/uds_events.go +++ /dev/null @@ -1,48 +0,0 @@ -package client - -import "github.com/kercylan98/minotaur/server" - -type ( - UDSConnectionClosedEventHandle func(conn *UnixDomainSocket, err any) - UDSConnectionOpenedEventHandle func(conn *UnixDomainSocket) - UDSConnectionReceivePacketEventHandle func(conn *UnixDomainSocket, packet server.Packet) -) - -type udsEvents struct { - UDSConnectionClosedEventHandles []UDSConnectionClosedEventHandle - UDSConnectionOpenedEventHandles []UDSConnectionOpenedEventHandle - UDSConnectionReceivePacketEventHandles []UDSConnectionReceivePacketEventHandle -} - -// RegUDSConnectionClosedEvent 注册连接关闭事件 -func (slf *udsEvents) RegUDSConnectionClosedEvent(handle UDSConnectionClosedEventHandle) { - slf.UDSConnectionClosedEventHandles = append(slf.UDSConnectionClosedEventHandles, handle) -} - -func (slf *udsEvents) OnUDSConnectionClosedEvent(conn *UnixDomainSocket, err any) { - for _, handle := range slf.UDSConnectionClosedEventHandles { - handle(conn, err) - } -} - -// RegUDSConnectionOpenedEvent 注册连接打开事件 -func (slf *udsEvents) RegUDSConnectionOpenedEvent(handle UDSConnectionOpenedEventHandle) { - slf.UDSConnectionOpenedEventHandles = append(slf.UDSConnectionOpenedEventHandles, handle) -} - -func (slf *udsEvents) OnUDSConnectionOpenedEvent(conn *UnixDomainSocket) { - for _, handle := range slf.UDSConnectionOpenedEventHandles { - handle(conn) - } -} - -// RegUDSConnectionReceivePacketEvent 注册连接接收数据包事件 -func (slf *udsEvents) RegUDSConnectionReceivePacketEvent(handle UDSConnectionReceivePacketEventHandle) { - slf.UDSConnectionReceivePacketEventHandles = append(slf.UDSConnectionReceivePacketEventHandles, handle) -} - -func (slf *udsEvents) OnUDSConnectionReceivePacketEvent(conn *UnixDomainSocket, packet server.Packet) { - for _, handle := range slf.UDSConnectionReceivePacketEventHandles { - handle(conn, packet) - } -} diff --git a/server/client/uds_test.go b/server/client/uds_test.go index 86e0a66..2ec57af 100644 --- a/server/client/uds_test.go +++ b/server/client/uds_test.go @@ -8,21 +8,21 @@ import ( ) func TestUnixDomainSocket_Write(t *testing.T) { - var close = make(chan struct{}) + var closed = make(chan struct{}) srv := server.New(server.NetworkUnix) - srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet server.Packet) { - t.Log(packet) + srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet []byte) { + t.Log(string(packet)) conn.Write(packet) }) srv.RegStartFinishEvent(func(srv *server.Server) { time.Sleep(time.Second) cli := client.NewUnixDomainSocket("./test.sock") - cli.RegUDSConnectionOpenedEvent(func(conn *client.UnixDomainSocket) { - conn.Write(server.NewPacketString("Hello~")) + cli.RegConnectionOpenedEvent(func(conn *client.Client) { + conn.Write([]byte("Hello~")) }) - cli.RegUDSConnectionReceivePacketEvent(func(conn *client.UnixDomainSocket, packet server.Packet) { + cli.RegConnectionReceivePacketEvent(func(conn *client.Client, wst int, packet []byte) { t.Log(packet) - close <- struct{}{} + closed <- struct{}{} }) if err := cli.Run(); err != nil { panic(err) @@ -34,6 +34,6 @@ func TestUnixDomainSocket_Write(t *testing.T) { } }() - <-close + <-closed srv.Shutdown() } diff --git a/server/client/websocket.go b/server/client/websocket.go index 59fd7b7..821019c 100644 --- a/server/client/websocket.go +++ b/server/client/websocket.go @@ -3,155 +3,50 @@ package client import ( "github.com/gorilla/websocket" "github.com/kercylan98/minotaur/server" - "github.com/kercylan98/minotaur/utils/concurrent" - "sync" - "time" ) // NewWebsocket 创建 websocket 客户端 -func NewWebsocket(addr string) *Websocket { - client := &Websocket{ - websocketEvents: new(websocketEvents), - addr: addr, - data: map[string]any{}, - } - return client +func NewWebsocket(addr string) *Client { + return NewClient(&Websocket{ + addr: addr, + }) } // Websocket websocket 客户端 type Websocket struct { - *websocketEvents - conn *websocket.Conn - addr string - data map[string]any - - mutex sync.Mutex - packetPool *concurrent.Pool[*Packet] - packets []*Packet - - accumulate []server.Packet + addr string + conn *websocket.Conn + clsoed bool } -// Run 启动 -func (slf *Websocket) Run() error { +func (slf *Websocket) Run(runState chan<- error, receive func(wst int, packet []byte)) { ws, _, err := websocket.DefaultDialer.Dial(slf.addr, nil) if err != nil { - return err - } - slf.conn = ws - var wait = new(sync.WaitGroup) - wait.Add(1) - go slf.writeLoop(wait) - wait.Wait() - go func() { - defer func() { - if err := recover(); err != nil { - slf.Close() - slf.OnWebsocketConnectionClosedEvent(slf, err) - } - }() - slf.OnWebsocketConnectionOpenedEvent(slf) - for slf.packetPool != nil { - messageType, packet, readErr := ws.ReadMessage() - if readErr != nil { - panic(readErr) - } - slf.OnWebsocketConnectionReceivePacketEvent(slf, server.NewWSPacket(messageType, packet)) - } - }() - return nil -} - -// Close 关闭 -func (slf *Websocket) Close() { - if slf.packetPool != nil { - slf.packetPool.Close() - slf.packetPool = nil - } - slf.packets = nil -} - -// IsConnected 是否已连接 -func (slf *Websocket) IsConnected() bool { - return slf.packetPool != nil -} - -// GetData 获取数据 -func (slf *Websocket) GetData(key string) any { - return slf.data[key] -} - -// SetData 设置数据 -func (slf *Websocket) SetData(key string, value any) { - slf.data[key] = value -} - -// Write 向连接中写入数据 -// - messageType: websocket模式中指定消息类型 -func (slf *Websocket) Write(packet server.Packet) { - if slf.packetPool == nil { - slf.accumulate = append(slf.accumulate, packet) + runState <- err return } - cp := slf.packetPool.Get() - cp.websocketMessageType = packet.WebsocketType - cp.packet = packet.Data - slf.mutex.Lock() - slf.packets = append(slf.packets, cp) - slf.mutex.Unlock() + slf.conn = ws + runState <- nil + for !slf.clsoed { + messageType, packet, readErr := ws.ReadMessage() + if readErr != nil { + panic(readErr) + } + receive(messageType, packet) + } } -// writeLoop 写循环 -func (slf *Websocket) writeLoop(wait *sync.WaitGroup) { - slf.packetPool = concurrent.NewPool[*Packet](10*1024, - func() *Packet { - return &Packet{} - }, func(data *Packet) { - data.packet = nil - data.websocketMessageType = 0 - data.callback = nil - }, - ) - slf.mutex.Lock() - for _, packet := range slf.accumulate { - cp := slf.packetPool.Get() - cp.websocketMessageType = packet.WebsocketType - cp.packet = packet.Data - slf.packets = append(slf.packets, cp) - } - slf.accumulate = nil - slf.mutex.Unlock() - defer func() { - if err := recover(); err != nil { - slf.Close() - } - }() - wait.Done() - for { - slf.mutex.Lock() - if slf.packetPool == nil { - slf.mutex.Unlock() - return - } - if len(slf.packets) == 0 { - slf.mutex.Unlock() - time.Sleep(50 * time.Millisecond) - continue - } - packets := slf.packets[0:] - slf.packets = slf.packets[0:0] - slf.mutex.Unlock() - for i := 0; i < len(packets); i++ { - data := packets[i] - var err = slf.conn.WriteMessage(data.websocketMessageType, data.packet) - callback := data.callback - slf.packetPool.Release(data) - if callback != nil { - callback(err) - } - if err != nil { - panic(err) - } - } +func (slf *Websocket) Write(packet *Packet) error { + if packet.wst == 0 { + packet.wst = server.WebsocketMessageTypeBinary } + return slf.conn.WriteMessage(packet.wst, packet.data) +} + +func (slf *Websocket) Close() { + slf.clsoed = true +} + +func (slf *Websocket) GetServerAddr() string { + return slf.addr } diff --git a/server/client/websocket_events.go b/server/client/websocket_events.go deleted file mode 100644 index 4b25b20..0000000 --- a/server/client/websocket_events.go +++ /dev/null @@ -1,48 +0,0 @@ -package client - -import "github.com/kercylan98/minotaur/server" - -type ( - WebsocketConnectionClosedEventHandle func(conn *Websocket, err any) - WebsocketConnectionOpenedEventHandle func(conn *Websocket) - WebsocketConnectionReceivePacketEventHandle func(conn *Websocket, packet server.Packet) -) - -type websocketEvents struct { - websocketConnectionClosedEventHandles []WebsocketConnectionClosedEventHandle - websocketConnectionOpenedEventHandles []WebsocketConnectionOpenedEventHandle - websocketConnectionReceivePacketEventHandles []WebsocketConnectionReceivePacketEventHandle -} - -// RegWebsocketConnectionClosedEvent 注册连接关闭事件 -func (slf *websocketEvents) RegWebsocketConnectionClosedEvent(handle WebsocketConnectionClosedEventHandle) { - slf.websocketConnectionClosedEventHandles = append(slf.websocketConnectionClosedEventHandles, handle) -} - -func (slf *websocketEvents) OnWebsocketConnectionClosedEvent(conn *Websocket, err any) { - for _, handle := range slf.websocketConnectionClosedEventHandles { - handle(conn, err) - } -} - -// RegWebsocketConnectionOpenedEvent 注册连接打开事件 -func (slf *websocketEvents) RegWebsocketConnectionOpenedEvent(handle WebsocketConnectionOpenedEventHandle) { - slf.websocketConnectionOpenedEventHandles = append(slf.websocketConnectionOpenedEventHandles, handle) -} - -func (slf *websocketEvents) OnWebsocketConnectionOpenedEvent(conn *Websocket) { - for _, handle := range slf.websocketConnectionOpenedEventHandles { - handle(conn) - } -} - -// RegWebsocketConnectionReceivePacketEvent 注册连接接收数据包事件 -func (slf *websocketEvents) RegWebsocketConnectionReceivePacketEvent(handle WebsocketConnectionReceivePacketEventHandle) { - slf.websocketConnectionReceivePacketEventHandles = append(slf.websocketConnectionReceivePacketEventHandles, handle) -} - -func (slf *websocketEvents) OnWebsocketConnectionReceivePacketEvent(conn *Websocket, packet server.Packet) { - for _, handle := range slf.websocketConnectionReceivePacketEventHandles { - handle(conn, packet) - } -} diff --git a/server/conn.go b/server/conn.go index 89a5ae5..e9af1d2 100644 --- a/server/conn.go +++ b/server/conn.go @@ -1,12 +1,14 @@ package server import ( + "context" "github.com/gorilla/websocket" "github.com/kercylan98/minotaur/utils/concurrent" - "github.com/kercylan98/minotaur/utils/super" + "github.com/kercylan98/minotaur/utils/log" "github.com/panjf2000/gnet" "github.com/xtaci/kcp-go/v5" "net" + "runtime/debug" "strings" "sync" "time" @@ -15,11 +17,14 @@ import ( // newKcpConn 创建一个处理KCP的连接 func newKcpConn(server *Server, session *kcp.UDPSession) *Conn { c := &Conn{ - server: server, - remoteAddr: session.RemoteAddr(), - ip: session.RemoteAddr().String(), - kcp: session, - data: map[any]any{}, + ctx: server.ctx, + connection: &connection{ + server: server, + remoteAddr: session.RemoteAddr(), + ip: session.RemoteAddr().String(), + kcp: session, + data: map[any]any{}, + }, } if index := strings.LastIndex(c.ip, ":"); index != -1 { c.ip = c.ip[0:index] @@ -34,11 +39,14 @@ func newKcpConn(server *Server, session *kcp.UDPSession) *Conn { // newKcpConn 创建一个处理GNet的连接 func newGNetConn(server *Server, conn gnet.Conn) *Conn { c := &Conn{ - server: server, - remoteAddr: conn.RemoteAddr(), - ip: conn.RemoteAddr().String(), - gn: conn, - data: map[any]any{}, + ctx: server.ctx, + connection: &connection{ + server: server, + remoteAddr: conn.RemoteAddr(), + ip: conn.RemoteAddr().String(), + gn: conn, + data: map[any]any{}, + }, } if index := strings.LastIndex(c.ip, ":"); index != -1 { c.ip = c.ip[0:index] @@ -53,11 +61,14 @@ func newGNetConn(server *Server, conn gnet.Conn) *Conn { // newKcpConn 创建一个处理WebSocket的连接 func newWebsocketConn(server *Server, ws *websocket.Conn, ip string) *Conn { c := &Conn{ - server: server, - remoteAddr: ws.RemoteAddr(), - ip: ip, - ws: ws, - data: map[any]any{}, + ctx: server.ctx, + connection: &connection{ + server: server, + remoteAddr: ws.RemoteAddr(), + ip: ip, + ws: ws, + data: map[any]any{}, + }, } var wait = new(sync.WaitGroup) wait.Add(1) @@ -69,18 +80,13 @@ func newWebsocketConn(server *Server, ws *websocket.Conn, ip string) *Conn { // newGatewayConn 创建一个处理网关消息的连接 func newGatewayConn(conn *Conn, connId string) *Conn { c := &Conn{ - server: conn.server, - data: map[any]any{}, + //ctx: server.ctx, + connection: &connection{ + server: conn.server, + data: map[any]any{}, + }, } - c.gw = func(packet Packet) { - var gp = GP{ - C: connId, - WT: packet.WebsocketType, - D: packet.Data, - T: time.Now().UnixNano(), - } - pd := super.MarshalJSON(&gp) - packet.Data = append(pd, 0xff) + c.gw = func(packet []byte) { conn.Write(packet) } return c @@ -89,10 +95,13 @@ func newGatewayConn(conn *Conn, connId string) *Conn { // NewEmptyConn 创建一个适用于测试的空连接 func NewEmptyConn(server *Server) *Conn { c := &Conn{ - server: server, - remoteAddr: &net.TCPAddr{}, - ip: "0.0.0.0:0", - data: map[any]any{}, + ctx: server.ctx, + connection: &connection{ + server: server, + remoteAddr: &net.TCPAddr{}, + ip: "0.0.0.0:0", + data: map[any]any{}, + }, } var wait = new(sync.WaitGroup) wait.Add(1) @@ -101,15 +110,21 @@ func NewEmptyConn(server *Server) *Conn { return c } -// Conn 服务器连接 +// Conn 服务器连接单次会话的包装 type Conn struct { + *connection + ctx context.Context +} + +// connection 长久保持的连接 +type connection struct { server *Server remoteAddr net.Addr ip string ws *websocket.Conn gn gnet.Conn kcp *kcp.UDPSession - gw func(packet Packet) + gw func(packet []byte) data map[any]any mutex sync.Mutex packetPool *concurrent.Pool[*connPacket] @@ -174,7 +189,7 @@ func (slf *Conn) Close() { slf.packets = nil } -// SetData 设置连接数据 +// SetData 设置连接数据,该数据将在连接关闭前始终存在 func (slf *Conn) SetData(key, value any) *Conn { slf.data[key] = value return slf @@ -185,6 +200,17 @@ func (slf *Conn) GetData(key any) any { return slf.data[key] } +// SetMessageData 设置消息数据,该数据将在消息处理完成后释放 +func (slf *Conn) SetMessageData(key, value any) *Conn { + slf.ctx = context.WithValue(slf.ctx, key, value) + return slf +} + +// GetMessageData 获取消息数据 +func (slf *Conn) GetMessageData(key any) any { + return slf.ctx.Value(key) +} + // ReleaseData 释放数据 func (slf *Conn) ReleaseData() *Conn { for k := range slf.data { @@ -198,28 +224,21 @@ func (slf *Conn) IsWebsocket() bool { return slf.server.network == NetworkWebsocket } -// Write 向连接中写入数据 -// - messageType: websocket模式中指定消息类型 -func (slf *Conn) Write(packet Packet) { - if slf.gw != nil { - slf.gw(packet) - return - } - packet = slf.server.OnConnectionWritePacketBeforeEvent(slf, packet) - if slf.packetPool == nil { - return - } - cp := slf.packetPool.Get() - cp.websocketMessageType = packet.WebsocketType - cp.packet = packet.Data - slf.mutex.Lock() - slf.packets = append(slf.packets, cp) - slf.mutex.Unlock() +// GetWST 获取websocket消息类型 +func (slf *Conn) GetWST() int { + wst, _ := slf.ctx.Value(contextKeyWST).(int) + return wst } -// WriteWithCallback 与 Write 相同,但是会在写入完成后调用 callback -// - 当 callback 为 nil 时,与 Write 相同 -func (slf *Conn) WriteWithCallback(packet Packet, callback func(err error)) { +// SetWST 设置websocket消息类型 +func (slf *Conn) SetWST(wst int) *Conn { + slf.ctx = context.WithValue(slf.ctx, contextKeyWST, wst) + return slf +} + +// Write 向连接中写入数据 +// - messageType: websocket模式中指定消息类型 +func (slf *Conn) Write(packet []byte, callback ...func(err error)) { if slf.gw != nil { slf.gw(packet) return @@ -229,9 +248,11 @@ func (slf *Conn) WriteWithCallback(packet Packet, callback func(err error)) { return } cp := slf.packetPool.Get() - cp.websocketMessageType = packet.WebsocketType - cp.packet = packet.Data - cp.callback = callback + cp.wst = slf.GetWST() + cp.packet = packet + if len(callback) > 0 { + cp.callback = callback[0] + } slf.mutex.Lock() slf.packets = append(slf.packets, cp) slf.mutex.Unlock() @@ -243,14 +264,16 @@ func (slf *Conn) writeLoop(wait *sync.WaitGroup) { func() *connPacket { return &connPacket{} }, func(data *connPacket) { + data.wst = 0 data.packet = nil - data.websocketMessageType = 0 data.callback = nil }, ) defer func() { if err := recover(); err != nil { slf.Close() + log.Error("WriteLoop", log.Any("Error", err)) + debug.PrintStack() } }() wait.Done() @@ -271,7 +294,7 @@ func (slf *Conn) writeLoop(wait *sync.WaitGroup) { data := packets[i] var err error if slf.IsWebsocket() { - err = slf.ws.WriteMessage(data.websocketMessageType, data.packet) + err = slf.ws.WriteMessage(data.wst, data.packet) } else { if slf.gn != nil { switch slf.server.network { diff --git a/server/conn_packet.go b/server/conn_packet.go index 0087d53..ea4708b 100644 --- a/server/conn_packet.go +++ b/server/conn_packet.go @@ -2,7 +2,7 @@ package server // connPacket 连接包 type connPacket struct { - websocketMessageType int // websocket 消息类型 - packet []byte // 数据包 - callback func(err error) // 回调函数 + wst int // websocket消息类型 + packet []byte // 数据包 + callback func(err error) // 回调函数 } diff --git a/server/constants.go b/server/constants.go index 79b6c80..b5033b0 100644 --- a/server/constants.go +++ b/server/constants.go @@ -26,3 +26,7 @@ const ( DefaultAsyncPoolSize = 256 DefaultWebsocketReadDeadline = 30 * time.Second ) + +const ( + contextKeyWST = "_wst" // WebSocket 消息类型 +) diff --git a/server/event.go b/server/event.go index 8f1d926..c6bb699 100644 --- a/server/event.go +++ b/server/event.go @@ -14,7 +14,7 @@ import ( type StartBeforeEventHandle func(srv *Server) type StartFinishEventHandle func(srv *Server) type StopEventHandle func(srv *Server) -type ConnectionReceivePacketEventHandle func(srv *Server, conn *Conn, packet Packet) +type ConnectionReceivePacketEventHandle func(srv *Server, conn *Conn, packet []byte) type ConnectionOpenedEventHandle func(srv *Server, conn *Conn) type ConnectionClosedEventHandle func(srv *Server, conn *Conn, err any) type ReceiveCrossPacketEventHandle func(srv *Server, senderServerId int64, packet []byte) @@ -22,7 +22,7 @@ type MessageErrorEventHandle func(srv *Server, message *Message, err error) type MessageLowExecEventHandle func(srv *Server, message *Message, cost time.Duration) type ConsoleCommandEventHandle func(srv *Server) type ConnectionOpenedAfterEventHandle func(srv *Server, conn *Conn) -type ConnectionWritePacketBeforeEventHandle func(srv *Server, conn *Conn, packet Packet) Packet +type ConnectionWritePacketBeforeEventHandle func(srv *Server, conn *Conn, packet []byte) []byte type ShuntChannelCreatedEventHandle func(srv *Server, guid int64) type ShuntChannelClosedEventHandle func(srv *Server, guid int64) type ConnectionPacketPreprocessEventHandle func(srv *Server, conn *Conn, packet []byte, abort func(), usePacket func(newPacket []byte)) @@ -201,7 +201,7 @@ func (slf *event) RegConnectionReceivePacketEvent(handle ConnectionReceivePacket log.Info("Server", log.String("RegEvent", runtimes.CurrentRunningFuncName()), log.String("handle", reflect.TypeOf(handle).String())) } -func (slf *event) OnConnectionReceivePacketEvent(conn *Conn, packet Packet) { +func (slf *event) OnConnectionReceivePacketEvent(conn *Conn, packet []byte) { slf.connectionReceivePacketEventHandles.RangeValue(func(index int, value ConnectionReceivePacketEventHandle) bool { value(slf.Server, conn, packet) return true @@ -278,7 +278,7 @@ func (slf *event) RegConnectionWritePacketBeforeEvent(handle ConnectionWritePack log.Info("Server", log.String("RegEvent", runtimes.CurrentRunningFuncName()), log.String("handle", reflect.TypeOf(handle).String())) } -func (slf *event) OnConnectionWritePacketBeforeEvent(conn *Conn, packet Packet) (newPacket Packet) { +func (slf *event) OnConnectionWritePacketBeforeEvent(conn *Conn, packet []byte) (newPacket []byte) { if slf.connectionWritePacketBeforeHandles.Len() == 0 { return packet } diff --git a/server/gateway/endpoint.go b/server/gateway/endpoint.go index 4715fed..b90f22f 100644 --- a/server/gateway/endpoint.go +++ b/server/gateway/endpoint.go @@ -1,18 +1,20 @@ package gateway import ( + "github.com/alphadose/haxmap" "github.com/kercylan98/minotaur/server" "github.com/kercylan98/minotaur/server/client" - "github.com/kercylan98/minotaur/utils/super" "time" ) // NewEndpoint 创建网关端点 -func NewEndpoint(name, address string, options ...EndpointOption) *Endpoint { +func NewEndpoint(gateway *Gateway, name string, client *client.Client, options ...EndpointOption) *Endpoint { endpoint := &Endpoint{ - client: client.NewWebsocket(address), - name: name, - address: address, + gateway: gateway, + client: client, + name: name, + address: client.GetServerAddr(), + connections: haxmap.New[string, *server.Conn](), } for _, option := range options { option(endpoint) @@ -22,19 +24,37 @@ func NewEndpoint(name, address string, options ...EndpointOption) *Endpoint { return 1 / (1 + 1.5*time.Duration(costUnixNano).Seconds()) } } - endpoint.client.RegWebsocketConnectionClosedEvent(endpoint.onConnectionClosed) - endpoint.client.RegWebsocketConnectionReceivePacketEvent(endpoint.onConnectionReceivePacket) + endpoint.client.RegConnectionClosedEvent(endpoint.onConnectionClosed) + endpoint.client.RegConnectionReceivePacketEvent(endpoint.onConnectionReceivePacket) return endpoint } // Endpoint 网关端点 type Endpoint struct { - client *client.Websocket // 端点客户端 - name string // 端点名称 - address string // 端点地址 - state float64 // 端点健康值(0为不可用,越高越优) - offline bool // 离线 - evaluator func(costUnixNano float64) float64 // 端点健康值评估函数 + gateway *Gateway + client *client.Client // 端点客户端 + name string // 端点名称 + address string // 端点地址 + state float64 // 端点健康值(0为不可用,越高越优) + offline bool // 离线 + evaluator func(costUnixNano float64) float64 // 端点健康值评估函数 + connections *haxmap.Map[string, *server.Conn] // 连接列表 +} + +// Link 连接端点 +func (slf *Endpoint) Link(conn *server.Conn) { + slf.connections.Set(conn.GetID(), conn) +} + +// Unlink 断开连接 +func (slf *Endpoint) Unlink(conn *server.Conn) { + slf.connections.Del(conn.GetID()) +} + +// GetLink 获取连接 +func (slf *Endpoint) GetLink(id string) *server.Conn { + conn, _ := slf.connections.Get(id) + return conn } // Offline 离线 @@ -55,24 +75,28 @@ func (slf *Endpoint) Connect() { } // Write 写入数据 -func (slf *Endpoint) Write(packet server.Packet) { - slf.client.Write(packet) +func (slf *Endpoint) Write(packet []byte, callback ...func(err error)) { + slf.client.Write(packet, callback...) +} + +// WriteWS 写入 websocket 数据 +func (slf *Endpoint) WriteWS(wst int, packet []byte, callback ...func(err error)) { + slf.client.WriteWS(wst, packet, callback...) } // onConnectionClosed 与端点连接断开事件 -func (slf *Endpoint) onConnectionClosed(conn *client.Websocket, err any) { +func (slf *Endpoint) onConnectionClosed(conn *client.Client, err any) { if !slf.offline { go slf.Connect() } } // onConnectionReceivePacket 接收到来自端点的数据包事件 -func (slf *Endpoint) onConnectionReceivePacket(conn *client.Websocket, packet server.Packet) { - var gp server.GP - if err := super.UnmarshalJSON(packet.Data[:len(packet.Data)-1], &gp); err != nil { +func (slf *Endpoint) onConnectionReceivePacket(conn *client.Client, wst int, packet []byte) { + addr, sendTime, packet, err := UnmarshalGatewayInPacket(packet) + if err != nil { panic(err) } - cur := time.Now().UnixNano() - slf.state = slf.evaluator(float64(cur - gp.T)) - conn.GetData(gp.C).(*server.Conn).Write(server.NewWSPacket(gp.WT, gp.D)) + slf.state = slf.evaluator(float64(time.Now().UnixNano() - sendTime)) + slf.GetLink(addr).SetWST(wst).Write(packet) } diff --git a/server/gateway/endpoint_manager.go b/server/gateway/endpoint_manager.go index ecda703..7695aa6 100644 --- a/server/gateway/endpoint_manager.go +++ b/server/gateway/endpoint_manager.go @@ -1,7 +1,7 @@ package gateway import ( - "github.com/kercylan98/minotaur/server" + "github.com/alphadose/haxmap" "github.com/kercylan98/minotaur/utils/concurrent" "github.com/kercylan98/minotaur/utils/random" ) @@ -10,7 +10,7 @@ import ( func NewEndpointManager() *EndpointManager { em := &EndpointManager{ endpoints: concurrent.NewBalanceMap[string, []*Endpoint](), - memory: concurrent.NewBalanceMap[string, *Endpoint](), + memory: haxmap.New[string, *Endpoint](), selector: func(endpoints []*Endpoint) *Endpoint { return endpoints[random.Int(0, len(endpoints)-1)] }, @@ -21,13 +21,15 @@ func NewEndpointManager() *EndpointManager { // EndpointManager 网关端点管理器 type EndpointManager struct { endpoints *concurrent.BalanceMap[string, []*Endpoint] - memory *concurrent.BalanceMap[string, *Endpoint] + memory *haxmap.Map[string, *Endpoint] selector func([]*Endpoint) *Endpoint } // GetEndpoint 获取端点 -func (slf *EndpointManager) GetEndpoint(name string, conn *server.Conn) (*Endpoint, error) { - endpoint, exist := slf.memory.GetExist(conn.GetID()) +// - name: 端点名称 +// - id: 使用端点的连接标识 +func (slf *EndpointManager) GetEndpoint(name, id string) (*Endpoint, error) { + endpoint, exist := slf.memory.Get(id) if exist { return endpoint, nil } @@ -53,7 +55,7 @@ func (slf *EndpointManager) GetEndpoint(name string, conn *server.Conn) (*Endpoi if endpoint == nil { return nil, ErrEndpointNotExists } - slf.memory.Set(conn.GetID(), endpoint) + slf.memory.Set(id, endpoint) return endpoint, nil } diff --git a/server/gateway/gateway.go b/server/gateway/gateway.go index 80666e8..13a4018 100644 --- a/server/gateway/gateway.go +++ b/server/gateway/gateway.go @@ -2,7 +2,6 @@ package gateway import ( "github.com/kercylan98/minotaur/server" - "github.com/kercylan98/minotaur/utils/super" "math" ) @@ -36,28 +35,30 @@ func (slf *Gateway) Shutdown() { slf.srv.Shutdown() } -// onConnectionOpened 连接打开事件 func (slf *Gateway) onConnectionOpened(srv *server.Server, conn *server.Conn) { - endpoint, err := slf.GetEndpoint("test", conn) + endpoint, err := slf.GetEndpoint("test", conn.GetID()) if err != nil { conn.Close() return } - endpoint.client.SetData(conn.GetID(), conn) - conn.SetData("endpoint", endpoint) + endpoint.Link(conn) } // onConnectionReceivePacket 连接接收数据包事件 -func (slf *Gateway) onConnectionReceivePacket(srv *server.Server, conn *server.Conn, packet server.Packet) { - var gp = server.GP{ - C: conn.GetID(), - WT: packet.WebsocketType, - D: packet.Data, +func (slf *Gateway) onConnectionReceivePacket(srv *server.Server, conn *server.Conn, packet []byte) { + endpoint, err := slf.GetEndpoint("test", conn.GetID()) + if err != nil { + conn.Close() + return } - pd := super.MarshalJSON(&gp) - packet.Data = append(pd, 0xff) - var endpoint, exist = conn.GetData("endpoint").(*Endpoint) - if exist { + packet, err = MarshalGatewayOutPacket(conn.GetID(), packet) + if err != nil { + conn.Close() + return + } + if conn.IsWebsocket() { + endpoint.WriteWS(conn.GetWST(), packet) + } else { endpoint.Write(packet) } } diff --git a/server/gateway/gateway_test.go b/server/gateway/gateway_test.go index 744162e..8b3edd0 100644 --- a/server/gateway/gateway_test.go +++ b/server/gateway/gateway_test.go @@ -3,15 +3,39 @@ package gateway_test import ( "fmt" "github.com/kercylan98/minotaur/server" - gateway2 "github.com/kercylan98/minotaur/server/gateway" + "github.com/kercylan98/minotaur/server/client" + "github.com/kercylan98/minotaur/server/gateway" "testing" "time" ) func TestGateway_RunEndpointServer(t *testing.T) { srv := server.New(server.NetworkWebsocket, server.WithDeadlockDetect(time.Second*3)) - srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet server.Packet) { - fmt.Println("endpoint receive packet", string(packet.Data)) + srv.RegConnectionClosedEvent(func(srv *server.Server, conn *server.Conn, err any) { + fmt.Println(err) + }) + srv.RegConnectionPacketPreprocessEvent(func(srv *server.Server, conn *server.Conn, packet []byte, abort func(), usePacket func(newPacket []byte)) { + addr, packet, err := gateway.UnmarshalGatewayOutPacket(packet) + if err != nil { + // 非网关的普通数据包 + return + } + usePacket(packet) + conn.SetMessageData("gw-addr", addr) + }) + srv.RegConnectionWritePacketBeforeEvent(func(srv *server.Server, conn *server.Conn, packet []byte) []byte { + addr, ok := conn.GetMessageData("gw-addr").(string) + if !ok { + return packet + } + packet, err := gateway.MarshalGatewayInPacket(addr, time.Now().Unix(), packet) + if err != nil { + panic(err) + } + return packet + }) + srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet []byte) { + fmt.Println("endpoint receive packet", string(packet)) conn.Write(packet) }) if err := srv.Run(":8889"); err != nil { @@ -21,9 +45,9 @@ func TestGateway_RunEndpointServer(t *testing.T) { func TestGateway_Run(t *testing.T) { srv := server.New(server.NetworkWebsocket, server.WithDeadlockDetect(time.Second*3)) - gw := gateway2.NewGateway(srv) + gw := gateway.NewGateway(srv) srv.RegStartFinishEvent(func(srv *server.Server) { - if err := gw.AddEndpoint(gateway2.NewEndpoint("test", "ws://127.0.0.1:8889")); err != nil { + if err := gw.AddEndpoint(gateway.NewEndpoint(gw, "test", client.NewWebsocket("ws://127.0.0.1:8889"))); err != nil { panic(err) } }) diff --git a/server/gateway/packet.go b/server/gateway/packet.go index 761fd8b..17bc884 100644 --- a/server/gateway/packet.go +++ b/server/gateway/packet.go @@ -1,7 +1,108 @@ package gateway -type Packet struct { - ConnID string - WebsocketType int - Data []byte +import ( + "encoding/binary" + "errors" + "fmt" + "net" + "strconv" +) + +var packetIdentifier = []byte{0xDE, 0xAD, 0xBE, 0xEF} + +// MarshalGatewayOutPacket 将数据包转换为网关出网数据包 +// - | identifier(4) | ipv4(4) | port(2) | packet | +func MarshalGatewayOutPacket(addr string, packet []byte) ([]byte, error) { + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + ipBytes := net.ParseIP(host).To4() + if ipBytes == nil { + return nil, errors.New("invalid IPv4 address") + } + port, err := strconv.Atoi(portStr) + if err != nil || port < 0 || port > 65535 { + return nil, errors.New("invalid port number") + } + portBytes := []byte{byte(port >> 8), byte(port & 0xFF)} + + result := append(packetIdentifier, ipBytes...) + result = append(result, portBytes...) + result = append(result, packet...) + + return result, nil +} + +// UnmarshalGatewayOutPacket 将网关出网数据包转换为数据包 +// - | identifier(4) | ipv4(4) | port(2) | packet | +func UnmarshalGatewayOutPacket(data []byte) (addr string, packet []byte, err error) { + if len(data) < 10 { + err = errors.New("data is too short to contain an IPv4 address and a port") + return + } + if !compareBytes(data[:4], packetIdentifier) { + err = errors.New("invalid identifier") + return + } + ipAddr := net.IP(data[4:8]).String() + port := uint16(data[8])<<8 | uint16(data[9]) + addr = fmt.Sprintf("%s:%d", ipAddr, port) + packet = data[10:] + + return addr, packet, nil +} + +// MarshalGatewayInPacket 将数据包转换为网关入网数据包 +// - | ipv4(4) | port(2) | cost(4) | packet | +func MarshalGatewayInPacket(addr string, currentTime int64, packet []byte) ([]byte, error) { + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + ipBytes := net.ParseIP(host).To4() + if ipBytes == nil { + return nil, errors.New("invalid IPv4 address") + } + port, err := strconv.Atoi(portStr) + if err != nil || port < 0 || port > 65535 { + return nil, errors.New("invalid port number") + } + portBytes := []byte{byte(port >> 8), byte(port & 0xFF)} + costBytes := make([]byte, 4) + binary.BigEndian.PutUint32(costBytes, uint32(currentTime)) + + result := append(ipBytes, portBytes...) + result = append(result, costBytes...) + result = append(result, packet...) + + return result, nil +} + +// UnmarshalGatewayInPacket 将网关入网数据包转换为数据包 +// - | ipv4(4) | port(2) | cost(4) | packet | +func UnmarshalGatewayInPacket(data []byte) (addr string, sendTime int64, packet []byte, err error) { + if len(data) < 10 { + err = errors.New("data is too short") + return + } + ipAddr := net.IP(data[:4]).String() + port := uint16(data[4])<<8 | uint16(data[5]) + addr = fmt.Sprintf("%s:%d", ipAddr, port) + sendTime = int64(binary.BigEndian.Uint32(data[6:10])) + packet = data[10:] + + return addr, sendTime, packet, nil +} + +func compareBytes(a, b []byte) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true } diff --git a/server/gnet.go b/server/gnet.go index 8231fad..03c3ec6 100644 --- a/server/gnet.go +++ b/server/gnet.go @@ -39,7 +39,7 @@ func (slf *gNet) AfterWrite(c gnet.Conn, b []byte) { } func (slf *gNet) React(packet []byte, c gnet.Conn) (out []byte, action gnet.Action) { - PushPacketMessage(slf.Server, c.Context().(*Conn), append(bytes.Clone(packet), 0)) + PushPacketMessage(slf.Server, c.Context().(*Conn), 0, append(bytes.Clone(packet), 0)) return nil, gnet.None } diff --git a/server/gp.go b/server/gp.go deleted file mode 100644 index e84defd..0000000 --- a/server/gp.go +++ /dev/null @@ -1,8 +0,0 @@ -package server - -type GP struct { - C string // 连接 ID - WT int // WebSocket 类型 - D []byte // 数据 - T int64 // 时间戳 -} diff --git a/server/lockstep/client.go b/server/lockstep/client.go index 1780764..90abfad 100644 --- a/server/lockstep/client.go +++ b/server/lockstep/client.go @@ -1,12 +1,10 @@ package lockstep -import "github.com/kercylan98/minotaur/server" - // Client 帧同步客户端接口定义 // - 客户端应该具备ID及写入数据包的实现 type Client[ID comparable] interface { // GetID 用户玩家ID GetID() ID // Write 写入数据包 - Write(packet server.Packet) + Write(packet []byte, callback ...func(err error)) } diff --git a/server/lockstep/evnets.go b/server/lockstep/evnets.go new file mode 100644 index 0000000..3f1a12a --- /dev/null +++ b/server/lockstep/evnets.go @@ -0,0 +1,3 @@ +package lockstep + +type StoppedEventHandle[ClientID comparable, Command any] func(lockstep *Lockstep[ClientID, Command]) diff --git a/server/lockstep/lockstep.go b/server/lockstep/lockstep.go index 4d1a3fb..7a962b6 100644 --- a/server/lockstep/lockstep.go +++ b/server/lockstep/lockstep.go @@ -2,8 +2,6 @@ package lockstep import ( "encoding/json" - "github.com/kercylan98/minotaur/component" - "github.com/kercylan98/minotaur/server" "github.com/kercylan98/minotaur/utils/concurrent" "github.com/kercylan98/minotaur/utils/timer" "sync" @@ -18,13 +16,13 @@ func NewLockstep[ClientID comparable, Command any](options ...Option[ClientID, C frames: concurrent.NewBalanceMap[int, []Command](), ticker: timer.GetTicker(10), frameRate: 15, - serialization: func(frame int, commands []Command) server.Packet { + serialization: func(frame int, commands []Command) []byte { frameStruct := struct { Frame int `json:"frame"` Commands []Command `json:"commands"` }{frame, commands} data, _ := json.Marshal(frameStruct) - return server.NewPacket(data) + return data }, clientCurrentFrame: concurrent.NewBalanceMap[ClientID, int](), } @@ -49,11 +47,11 @@ type Lockstep[ClientID comparable, Command any] struct { clientCurrentFrame *concurrent.BalanceMap[ClientID, int] // 客户端当前帧数 running atomic.Bool - frameRate int // 帧率(每秒N帧) - frameLimit int // 帧上限 - serialization func(frame int, commands []Command) server.Packet // 序列化函数 + frameRate int // 帧率(每秒N帧) + frameLimit int // 帧上限 + serialization func(frame int, commands []Command) []byte // 序列化函数 - lockstepStoppedEventHandles []component.LockstepStoppedEventHandle[ClientID, Command] + lockstepStoppedEventHandles []StoppedEventHandle[ClientID, Command] } // JoinClient 加入客户端到广播队列中 @@ -156,7 +154,7 @@ func (slf *Lockstep[ClientID, Command]) GetFrames() [][]Command { } // RegLockstepStoppedEvent 当广播停止时将触发被注册的事件处理函数 -func (slf *Lockstep[ClientID, Command]) RegLockstepStoppedEvent(handle component.LockstepStoppedEventHandle[ClientID, Command]) { +func (slf *Lockstep[ClientID, Command]) RegLockstepStoppedEvent(handle StoppedEventHandle[ClientID, Command]) { slf.lockstepStoppedEventHandles = append(slf.lockstepStoppedEventHandles, handle) } diff --git a/server/lockstep/lockstep_options.go b/server/lockstep/lockstep_options.go index 4be3379..a29b782 100644 --- a/server/lockstep/lockstep_options.go +++ b/server/lockstep/lockstep_options.go @@ -1,7 +1,5 @@ package lockstep -import "github.com/kercylan98/minotaur/server" - type Option[ClientID comparable, Command any] func(lockstep *Lockstep[ClientID, Command]) // WithFrameLimit 通过特定逻辑帧上限创建锁步(帧)同步组件 @@ -31,7 +29,7 @@ func WithFrameRate[ClientID comparable, Command any](frameRate int) Option[Clien // Frame int `json:"frame"` // Commands []Command `json:"commands"` // } -func WithSerialization[ClientID comparable, Command any](handle func(frame int, commands []Command) server.Packet) Option[ClientID, Command] { +func WithSerialization[ClientID comparable, Command any](handle func(frame int, commands []Command) []byte) Option[ClientID, Command] { return func(lockstep *Lockstep[ClientID, Command]) { lockstep.serialization = handle } diff --git a/server/message.go b/server/message.go index 8f7c17f..4b32257 100644 --- a/server/message.go +++ b/server/message.go @@ -1,6 +1,7 @@ package server import ( + "context" "encoding/json" "fmt" "reflect" @@ -104,10 +105,10 @@ func (slf MessageType) String() string { } // PushPacketMessage 向特定服务器中推送 MessageTypePacket 消息 -func PushPacketMessage(srv *Server, conn *Conn, packet []byte, mark ...any) { +func PushPacketMessage(srv *Server, conn *Conn, wst int, packet []byte, mark ...any) { msg := srv.messagePool.Get() msg.t = MessageTypePacket - msg.attrs = append([]any{conn, packet}, mark...) + msg.attrs = append([]any{&Conn{ctx: context.WithValue(conn.ctx, contextKeyWST, wst), connection: conn.connection}, packet}, mark...) srv.pushMessage(msg) } diff --git a/server/packet.go b/server/packet.go deleted file mode 100644 index 9007324..0000000 --- a/server/packet.go +++ /dev/null @@ -1,42 +0,0 @@ -package server - -// NewPacket 创建一个数据包 -func NewPacket(data []byte) Packet { - return Packet{ - Data: data, - } -} - -// NewWSPacket 创建一个 websocket 数据包 -func NewWSPacket(websocketType int, data []byte) Packet { - return Packet{ - WebsocketType: websocketType, - Data: data, - } -} - -// NewPacketString 创建一个字符串数据包 -func NewPacketString(data string) Packet { - return Packet{ - Data: []byte(data), - } -} - -// NewWSPacketString 创建一个 websocket 字符串数据包 -func NewWSPacketString(websocketType int, data string) Packet { - return Packet{ - WebsocketType: websocketType, - Data: []byte(data), - } -} - -// Packet 数据包 -type Packet struct { - WebsocketType int // websocket 消息类型 - Data []byte // 数据 -} - -// String 转换为字符串 -func (slf Packet) String() string { - return string(slf.Data) -} diff --git a/server/server.go b/server/server.go index 7e7585f..a5f51b8 100644 --- a/server/server.go +++ b/server/server.go @@ -36,6 +36,7 @@ func New(network Network, options ...Option) *Server { online: concurrent.NewBalanceMap[string, *Conn](), closeChannel: make(chan struct{}, 1), systemSignal: make(chan os.Signal, 1), + ctx: context.Background(), } server.event = newEvent(server) @@ -96,6 +97,7 @@ type Server struct { channelGenerator func(guid int64) chan *Message // 消息管道生成器 shuntMatcher func(conn *Conn) (guid int64, allowToCreate bool) // 分流管道匹配器 messageCounter atomic.Int64 // 消息计数器 + ctx context.Context // 上下文 } // Run 使用特定地址运行服务器 @@ -212,7 +214,7 @@ func (slf *Server) Run(addr string) error { if err != nil { panic(err) } - PushPacketMessage(slf, conn, buf[:n]) + PushPacketMessage(slf, conn, 0, buf[:n]) } }(conn) } @@ -303,7 +305,7 @@ func (slf *Server) Run(addr string) error { if len(slf.supportMessageTypes) > 0 && !slf.supportMessageTypes[messageType] { panic(ErrWebsocketIllegalMessageType) } - PushPacketMessage(slf, conn, append(packet, byte(messageType))) + PushPacketMessage(slf, conn, messageType, packet) } }) go func() { @@ -629,25 +631,8 @@ func (slf *Server) dispatchMessage(msg *Message) { case MessageTypePacket: var conn = attrs[0].(*Conn) var packet = attrs[1].([]byte) - var wst = int(packet[len(packet)-1]) - if len(packet) >= 2 { - var ct = packet[len(packet)-2] - if ct == 0xff { - var gp GP - if err := super.UnmarshalJSON(packet[:len(packet)-2], &gp); err != nil { - panic(err) - } - packet = gp.D - conn = newGatewayConn(conn, gp.C) - } else { - packet = packet[:len(packet)-1] - } - } else { - packet = packet[:len(packet)-1] - } - if !slf.OnConnectionPacketPreprocessEvent(conn, packet, func(newPacket []byte) { packet = newPacket }) { - slf.OnConnectionReceivePacketEvent(conn, Packet{Data: packet, WebsocketType: wst}) + slf.OnConnectionReceivePacketEvent(conn, packet) } case MessageTypeError: err, action := attrs[0].(error), attrs[1].(MessageErrorAction) diff --git a/server/server_example_test.go b/server/server_example_test.go index 2848d63..8e6c512 100644 --- a/server/server_example_test.go +++ b/server/server_example_test.go @@ -11,7 +11,7 @@ func ExampleNew() { server.WithPProf("/debug/pprof"), ) - srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet server.Packet) { + srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet []byte) { conn.Write(packet) }) diff --git a/utils/stream/slice.go b/utils/stream/slice.go index 561f2ba..34d21ba 100644 --- a/utils/stream/slice.go +++ b/utils/stream/slice.go @@ -51,7 +51,7 @@ func (slf Slice[V]) RandomKeep(n int) Slice[V] { if n >= length { return slf } - var hit = make([]int, length, length) + var hit = make([]int, length) for i := 0; i < n; i++ { hit[i] = 1 } @@ -71,7 +71,7 @@ func (slf Slice[V]) RandomDelete(n int) Slice[V] { if n >= length { return slf[:0] } - var hit = make([]int, length, length) + var hit = make([]int, length) for i := 0; i < n; i++ { hit[i] = 1 } diff --git a/utils/super/stack.go b/utils/super/stack.go index 760d469..25ea97c 100644 --- a/utils/super/stack.go +++ b/utils/super/stack.go @@ -20,7 +20,7 @@ type StackGo struct { // Wait 等待收集消息堆栈 // - 在调用 Wait 函数后,当前协程将会被挂起,直到调用 Stack 或 GiveUp 函数 func (slf *StackGo) Wait() { - slf.stack = make(chan *struct{}, 0) + slf.stack = make(chan *struct{}) if s := <-slf.stack; s != nil { slf.collect <- debug.Stack() }