vRp.CD2g_test/server/internal/v2/network/wsbsocket_wrapper.go

163 lines
3.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package network
import (
"bytes"
"errors"
"fmt"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
"github.com/kercylan98/minotaur/utils/super"
"github.com/panjf2000/gnet/v2"
"golang.org/x/net/context"
"io"
"time"
)
// newWebsocketWrapper builds the websocket wrapper for conn.
//
// The wrapper gets its own cancellable context derived from ctx, starts in the
// "not yet upgraded" state, and records the creation time as its active time.
func newWebsocketWrapper(ctx context.Context, conn gnet.Conn) *websocketWrapper {
	childCtx, cancel := context.WithCancel(ctx)
	return &websocketWrapper{
		ctx:    childCtx,
		cancel: cancel,
		conn:   conn,
		active: time.Now(),
		// upgraded is left at its zero value (false): no handshake has happened yet.
	}
}
// websocketWrapper is the websocket wrapper around a gnet connection. It keeps
// the handshake state and the buffers used to reassemble websocket frames from
// the bytes read off the connection.
type websocketWrapper struct {
ctx context.Context // cancellable context derived from the one passed to newWebsocketWrapper
cancel context.CancelFunc // cancel function paired with ctx
conn gnet.Conn // the underlying connection
upgraded bool // whether the websocket handshake has already completed
hs ws.Handshake // handshake information returned by the upgrader
active time.Time // active time; set to time.Now() by newWebsocketWrapper
buf bytes.Buffer // inbound buffer: bytes read from conn that have not been consumed yet
header *ws.Header // header of the frame currently being read; nil when no frame is in progress
cache bytes.Buffer // cached data: re-encoded frame headers and payloads, held until a frame with Fin set arrives
}
// readToBuffer drains everything the connection currently has buffered and
// appends it to the wrapper's inbound buffer.
//
// It returns the connection's read error as-is, or an "incomplete data" error
// when fewer bytes than the connection reported as buffered could be read; in
// both cases nothing is appended to the inbound buffer.
func (w *websocketWrapper) readToBuffer() error {
	pending := w.conn.InboundBuffered()
	data := make([]byte, pending)

	got, err := w.conn.Read(data)
	switch {
	case err != nil:
		return err
	case got < pending:
		return fmt.Errorf("incomplete data, read buffer bytes failed! size: %d, read: %d", pending, got)
	}

	w.buf.Write(data)
	return nil
}
// upgrade performs the websocket handshake using the bytes accumulated in the
// inbound buffer. It is a no-op once the connection has been upgraded.
//
// The handshake is parsed from a temporary reader over the buffered bytes, so
// nothing is consumed from the inbound buffer until the outcome is known; any
// response produced by the upgrader is written directly to the connection.
//
//   - On success the bytes read by the upgrader are dropped from the buffer and
//     the wrapper is marked as upgraded.
//   - If the handshake data is still incomplete (io.EOF or
//     io.ErrUnexpectedEOF, possibly wrapped) the buffer is left untouched and
//     that error is returned, so the handshake can be retried once more data
//     has been buffered.
//   - On any other error the bytes read by the upgrader are dropped and the
//     error is returned.
func (w *websocketWrapper) upgrade(upgrader ws.Upgrader) (err error) {
	if w.upgraded {
		return nil
	}

	reader := bytes.NewReader(w.buf.Bytes())
	total := reader.Len()
	w.hs, err = upgrader.Upgrade(super.ReadWriter{
		Reader: reader,
		Writer: w.conn,
	})

	// Partial handshake: keep every buffered byte for the next attempt.
	// errors.Is (rather than ==) also recognises a wrapped EOF, matching the
	// check used when reading frame headers.
	if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
		return err
	}

	// The upgrader reached a verdict; discard exactly what it consumed.
	w.buf.Next(total - reader.Len())
	if err != nil {
		return err
	}

	w.upgraded = true
	return nil
}
// decode reads every complete message currently available and returns only the
// text and binary ones.
//
// Control messages are not returned: each one is passed to
// wsutil.HandleClientControlMessage together with the connection. Messages
// with any other opcode are dropped. If reading or control handling fails, the
// error is returned together with the unfiltered result of read.
func (w *websocketWrapper) decode() (messages []wsutil.Message, err error) {
	messages, err = w.read()
	if err != nil {
		return messages, err
	}

	payloads := make([]wsutil.Message, 0, len(messages))
	for _, msg := range messages {
		switch {
		case msg.OpCode.IsControl():
			if err = wsutil.HandleClientControlMessage(w.conn, msg); err != nil {
				return messages, err
			}
		case msg.OpCode == ws.OpText || msg.OpCode == ws.OpBinary:
			payloads = append(payloads, msg)
		}
	}
	return payloads, nil
}
// read parses as many complete messages as the inbound buffer (w.buf) holds.
//
// Each pass of the loop handles one frame: its header is read from w.buf and
// re-encoded into w.cache, then its payload is copied from w.buf into w.cache.
// When the frame has Fin set, the cached bytes are decoded with
// wsutil.ReadClientMessage and appended to messages, and the cache is reset.
//
// The loop only exits through a return. When the buffered data ends in the
// middle of a header or payload, read returns the messages decoded so far with
// a nil error; w.header and w.cache keep the partial state so a later call can
// continue where this one stopped.
func (w *websocketWrapper) read() (messages []wsutil.Message, err error) {
var buf = &w.buf
for {
// No frame in progress: read a header from the buffer and write it to the cache.
if w.header == nil {
if buf.Len() < ws.MinHeaderSize {
return // incomplete data: wait for more bytes
}
var header ws.Header
if buf.Len() >= ws.MaxHeaderSize {
// At least ws.MaxHeaderSize bytes are buffered, so the header is read
// straight from the buffer.
header, err = ws.ReadHeader(buf)
if err != nil {
return
}
} else {
// Fewer than ws.MaxHeaderSize bytes are buffered, so the header may be
// incomplete. Try to read it through a separate reader so that the buffer
// itself is not consumed if the header turns out to be truncated.
reader := bytes.NewReader(buf.Bytes())
currLen := reader.Len()
header, err = ws.ReadHeader(reader)
skip := currLen - reader.Len()
if err != nil {
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
// Truncated header: leave the buffer untouched and report what was decoded so far.
return messages, nil
}
buf.Next(skip)
return nil, err
}
// Header read successfully: drop its bytes from the real buffer.
buf.Next(skip)
}
w.header = &header
if err = ws.WriteHeader(&w.cache, header); err != nil {
return nil, err
}
}
// Move the frame payload from the buffer into the cache.
if n := int(w.header.Length); n > 0 {
if buf.Len() < n {
return // incomplete data: w.header stays set, so the next call resumes at this payload
}
if _, err = io.CopyN(&w.cache, buf, int64(n)); err != nil {
return
}
}
// Fin set: the message is complete, so decode the cached frame(s). Otherwise
// the message is split across several frames and the cache keeps accumulating.
if w.header.Fin {
messages, err = wsutil.ReadClientMessage(&w.cache, messages)
if err != nil {
return
}
w.cache.Reset()
}
// This frame is fully handled; the next pass starts with a new header.
w.header = nil
}
}