diff --git a/server/client/websocket.go b/server/client/websocket.go new file mode 100644 index 0000000..52a7e98 --- /dev/null +++ b/server/client/websocket.go @@ -0,0 +1,144 @@ +package client + +import ( + "github.com/gorilla/websocket" + "github.com/kercylan98/minotaur/server" + "github.com/kercylan98/minotaur/utils/concurrent" + "sync" + "time" +) + +// NewWebsocket 创建 websocket 客户端 +func NewWebsocket(addr string) *Websocket { + client := &Websocket{ + websocketEvents: new(websocketEvents), + addr: addr, + data: map[string]any{}, + } + return client +} + +// Websocket websocket 客户端 +type Websocket struct { + *websocketEvents + conn *websocket.Conn + addr string + data map[string]any + + mutex sync.Mutex + packetPool *concurrent.Pool[*websocketPacket] + packets []*websocketPacket +} + +// Run 启动 +func (slf *Websocket) Run() error { + ws, _, err := websocket.DefaultDialer.Dial(slf.addr, nil) + if err != nil { + return err + } + slf.conn = ws + var wait = new(sync.WaitGroup) + wait.Add(1) + go slf.writeLoop(wait) + wait.Wait() + go func() { + defer func() { + if err := recover(); err != nil { + slf.Close() + slf.OnConnectionClosedEvent(slf, err) + } + }() + slf.OnConnectionOpenedEvent(slf) + for slf.packetPool != nil { + messageType, packet, readErr := ws.ReadMessage() + if readErr != nil { + panic(readErr) + } + slf.OnConnectionReceivePacketEvent(slf, server.NewWSPacket(messageType, packet)) + } + }() + return nil +} + +// Close 关闭 +func (slf *Websocket) Close() { + if slf.packetPool != nil { + slf.packetPool.Close() + slf.packetPool = nil + } + slf.packets = nil +} + +// IsConnected 是否已连接 +func (slf *Websocket) IsConnected() bool { + return slf.packetPool != nil +} + +// GetData 获取数据 +func (slf *Websocket) GetData(key string) any { + return slf.data[key] +} + +// SetData 设置数据 +func (slf *Websocket) SetData(key string, value any) { + slf.data[key] = value +} + +// Write 向连接中写入数据 +// - messageType: websocket模式中指定消息类型 +func (slf *Websocket) Write(packet server.Packet) { + if slf.packetPool == nil { + return + } + cp := slf.packetPool.Get() + cp.websocketMessageType = packet.WebsocketType + cp.packet = packet.Data + slf.mutex.Lock() + slf.packets = append(slf.packets, cp) + slf.mutex.Unlock() +} + +// writeLoop 写循环 +func (slf *Websocket) writeLoop(wait *sync.WaitGroup) { + slf.packetPool = concurrent.NewPool[*websocketPacket](10*1024, + func() *websocketPacket { + return &websocketPacket{} + }, func(data *websocketPacket) { + data.packet = nil + data.websocketMessageType = 0 + data.callback = nil + }, + ) + defer func() { + if err := recover(); err != nil { + slf.Close() + } + }() + wait.Done() + for { + slf.mutex.Lock() + if slf.packetPool == nil { + return + } + if len(slf.packets) == 0 { + slf.mutex.Unlock() + time.Sleep(50 * time.Millisecond) + continue + } + packets := slf.packets[0:] + slf.packets = slf.packets[0:0] + slf.mutex.Unlock() + for i := 0; i < len(packets); i++ { + data := packets[i] + var err = slf.conn.WriteMessage(data.websocketMessageType, data.packet) + callback := data.callback + slf.packetPool.Release(data) + if callback != nil { + callback(err) + } + if err != nil { + panic(err) + } + } + } +} diff --git a/server/client/websocket_events.go b/server/client/websocket_events.go new file mode 100644 index 0000000..9a33ada --- /dev/null +++ b/server/client/websocket_events.go @@ -0,0 +1,48 @@ +package client + +import "github.com/kercylan98/minotaur/server" + +type ( + ConnectionClosedEventHandle func(conn *Websocket, err any) + ConnectionOpenedEventHandle func(conn *Websocket) + ConnectionReceivePacketEventHandle func(conn *Websocket, packet server.Packet) +) + +type websocketEvents struct { + connectionClosedEventHandles []ConnectionClosedEventHandle + connectionOpenedEventHandles []ConnectionOpenedEventHandle + connectionReceivePacketEventHandles []ConnectionReceivePacketEventHandle +} + +// RegConnectionClosedEvent 注册连接关闭事件 +func (slf *websocketEvents) RegConnectionClosedEvent(handle ConnectionClosedEventHandle) { + slf.connectionClosedEventHandles = append(slf.connectionClosedEventHandles, handle) +} + +func (slf *websocketEvents) OnConnectionClosedEvent(conn *Websocket, err any) { + for _, handle := range slf.connectionClosedEventHandles { + handle(conn, err) + } +} + +// RegConnectionOpenedEvent 注册连接打开事件 +func (slf *websocketEvents) RegConnectionOpenedEvent(handle ConnectionOpenedEventHandle) { + slf.connectionOpenedEventHandles = append(slf.connectionOpenedEventHandles, handle) +} + +func (slf *websocketEvents) OnConnectionOpenedEvent(conn *Websocket) { + for _, handle := range slf.connectionOpenedEventHandles { + handle(conn) + } +} + +// RegConnectionReceivePacketEvent 注册连接接收数据包事件 +func (slf *websocketEvents) RegConnectionReceivePacketEvent(handle ConnectionReceivePacketEventHandle) { + slf.connectionReceivePacketEventHandles = append(slf.connectionReceivePacketEventHandles, handle) +} + +func (slf *websocketEvents) OnConnectionReceivePacketEvent(conn *Websocket, packet server.Packet) { + for _, handle := range slf.connectionReceivePacketEventHandles { + handle(conn, packet) + } +} diff --git a/server/client/websocket_packet.go b/server/client/websocket_packet.go new file mode 100644 index 0000000..9ebf2fa --- /dev/null +++ b/server/client/websocket_packet.go @@ -0,0 +1,7 @@ +package client + +type websocketPacket struct { + websocketMessageType int // websocket 消息类型 + packet []byte // 数据包 + callback func(err error) // 回调函数 +}