From b5c25a3dc8842dcd9b6b457ed2a8fd195b24fe1d Mon Sep 17 00:00:00 2001 From: kercylan98 Date: Tue, 9 May 2023 18:04:05 +0800 Subject: [PATCH] =?UTF-8?q?websocket=E6=94=AF=E6=8C=81=E8=BF=87=E6=BB=A4?= =?UTF-8?q?=E6=B6=88=E6=81=AF=E7=B1=BB=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/errors.go | 15 +++++++++------ server/options.go | 33 +++++++++++++++++++++++++++++++++ server/server.go | 19 +++++++++++-------- 3 files changed, 53 insertions(+), 14 deletions(-) diff --git a/server/errors.go b/server/errors.go index 6513727..cb82bbe 100644 --- a/server/errors.go +++ b/server/errors.go @@ -3,10 +3,13 @@ package server import "errors" var ( - ErrConstructed = errors.New("the Server must be constructed using the server.New function") - ErrCanNotSupportNetwork = errors.New("can not support network") - ErrMessageTypePacketAttrs = errors.New("MessageTypePacket must contain *Conn and []byte") - ErrMessageTypeErrorAttrs = errors.New("MessageTypePacket must contain error and MessageErrorAction") - ErrNetworkOnlySupportHttp = errors.New("the current network mode is not compatible with HttpRouter, only NetworkHttp is supported") - ErrNetworkIncompatibleHttp = errors.New("the current network mode is not compatible with NetworkHttp") + ErrConstructed = errors.New("the Server must be constructed using the server.New function") + ErrCanNotSupportNetwork = errors.New("can not support network") + ErrMessageTypePacketAttrs = errors.New("MessageTypePacket must contain *Conn and []byte") + ErrMessageTypeErrorAttrs = errors.New("MessageTypePacket must contain error and MessageErrorAction") + ErrNetworkOnlySupportHttp = errors.New("the current network mode is not compatible with HttpRouter, only NetworkHttp is supported") + ErrNetworkIncompatibleHttp = errors.New("the current network mode is not compatible with NetworkHttp") + ErrWebsocketMessageTypeException = errors.New("unknown message type, will not work") + ErrNotWebsocketUseMessageType = errors.New("message type filtering only supports websocket and does not take effect") + ErrWebsocketIllegalMessageType = errors.New("illegal message type") ) diff --git a/server/options.go b/server/options.go index 3ad5eb8..83ac09a 100644 --- a/server/options.go +++ b/server/options.go @@ -5,6 +5,19 @@ import ( "go.uber.org/zap" ) +const ( + // WebsocketMessageTypeText 表示文本数据消息。文本消息负载被解释为 UTF-8 编码的文本数据 + WebsocketMessageTypeText = 1 + // WebsocketMessageTypeBinary 表示二进制数据消息 + WebsocketMessageTypeBinary = 2 + // WebsocketMessageTypeClose 表示关闭控制消息。可选消息负载包含数字代码和文本。使用 FormatCloseMessage 函数来格式化关闭消息负载 + WebsocketMessageTypeClose = 8 + // WebsocketMessageTypePing 表示 ping 控制消息。可选的消息负载是 UTF-8 编码的文本 + WebsocketMessageTypePing = 9 + // WebsocketMessageTypePong 表示一个 pong 控制消息。可选的消息负载是 UTF-8 编码的文本 + WebsocketMessageTypePong = 10 +) + type Option func(srv *Server) // WithProd 通过生产模式运行服务器 @@ -14,6 +27,26 @@ func WithProd() Option { } } +// WithWebsocketMessageType 设置仅支持特定类型的Websocket消息 +func WithWebsocketMessageType(messageTypes ...int) Option { + return func(srv *Server) { + if srv.network != NetworkWebsocket { + log.Warn("WitchWebsocketMessageType", zap.String("Network", string(srv.network)), zap.Error(ErrNotWebsocketUseMessageType)) + return + } + var supports = make(map[int]bool) + for _, messageType := range messageTypes { + switch messageType { + case WebsocketMessageTypeText, WebsocketMessageTypeBinary, WebsocketMessageTypeClose, WebsocketMessageTypePing, WebsocketMessageTypePong: + supports[messageType] = true + default: + log.Warn("WitchWebsocketMessageType", zap.Int("MessageType", messageType), zap.Error(ErrWebsocketMessageTypeException)) + } + } + srv.supportMessageTypes = supports + } +} + // WithMessageBufferSize 通过特定的消息缓冲池大小运行服务器 // - 默认大小为 1024 // - 消息数量超出这个值的时候,消息处理将会造成更大的开销(频繁创建新的结构体),同时服务器将输出警告内容 diff --git a/server/server.go b/server/server.go index 872ce39..cc6154b 100644 --- a/server/server.go +++ b/server/server.go @@ -47,12 +47,13 @@ func New(network Network, options ...Option) *Server { // Server 网络服务器 type Server struct { *event - network Network // 网络类型 - addr string // 侦听地址 - options []Option // 选项 - ginServer *gin.Engine // HTTP模式下的路由器 - httpServer *http.Server // HTTP模式下的服务器 - grpcServer *grpc.Server // GRPC模式下的服务器 + network Network // 网络类型 + addr string // 侦听地址 + options []Option // 选项 + ginServer *gin.Engine // HTTP模式下的路由器 + httpServer *http.Server // HTTP模式下的服务器 + grpcServer *grpc.Server // GRPC模式下的服务器 + supportMessageTypes map[int]bool // websocket模式下支持的消息类型 gServer *gNet // TCP或UDP模式下的服务器 messagePool *synchronization.Pool[*message] // 消息池 @@ -228,12 +229,14 @@ func (slf *Server) Run(addr string) error { if err := ws.SetReadDeadline(time.Now().Add(time.Second * 30)); err != nil { panic(err) } - _, packet, err := ws.ReadMessage() + messageType, packet, err := ws.ReadMessage() if err != nil { panic(err) } + if !slf.supportMessageTypes[messageType] { + panic(ErrWebsocketIllegalMessageType) + } slf.PushMessage(MessageTypePacket, conn, packet) - } }) go func() {