服务器连接异步读写

This commit is contained in:
kercylan98 2023-05-15 12:32:53 +08:00
parent b28badbaab
commit 57460ff40b
6 changed files with 104 additions and 82 deletions

View File

@ -5,6 +5,9 @@ import "github.com/kercylan98/minotaur/server"
// 无意义的测试main入口
func main() {
srv := server.New(server.NetworkWebsocket, server.WithConnectPacketDiversion(3, 2))
srv.RegConnectionReceiveWebsocketPacketEvent(func(srv *server.Server, conn *server.Conn, packet []byte, messageType int) {
conn.Write(packet, messageType)
})
if err := srv.Run(":8999"); err != nil {
panic(err)
}

View File

@ -2,10 +2,13 @@ package server
import (
"github.com/gorilla/websocket"
"github.com/kercylan98/minotaur/utils/synchronization"
"github.com/panjf2000/gnet"
"github.com/xtaci/kcp-go/v5"
"net"
"strings"
"sync"
"time"
)
// newKcpConn 创建一个处理KCP的连接
@ -24,6 +27,7 @@ func newKcpConn(server *Server, session *kcp.UDPSession) *Conn {
if index := strings.LastIndex(c.ip, ":"); index != -1 {
c.ip = c.ip[0:index]
}
go c.writeLoop()
return c
}
@ -42,12 +46,13 @@ func newGNetConn(server *Server, conn gnet.Conn) *Conn {
if index := strings.LastIndex(c.ip, ":"); index != -1 {
c.ip = c.ip[0:index]
}
go c.writeLoop()
return c
}
// newKcpConn 创建一个处理WebSocket的连接
func newWebsocketConn(server *Server, ws *websocket.Conn, ip string) *Conn {
return &Conn{
c := &Conn{
server: server,
remoteAddr: ws.RemoteAddr(),
ip: ip,
@ -57,6 +62,8 @@ func newWebsocketConn(server *Server, ws *websocket.Conn, ip string) *Conn {
},
data: map[any]any{},
}
go c.writeLoop()
return c
}
// Conn 服务器连接
@ -69,6 +76,9 @@ type Conn struct {
kcp *kcp.UDPSession
write func(data []byte) error
data map[any]any
mutex sync.Mutex
packetPool *synchronization.Pool[*connPacket]
packets []*connPacket
}
func (slf *Conn) RemoteAddr() net.Addr {
@ -86,19 +96,20 @@ func (slf *Conn) GetIP() string {
// Write 向连接中写入数据
// - messageType: websocket模式中指定消息类型
func (slf *Conn) Write(data []byte, messageType ...int) {
if slf.IsWebsocket() {
cp := slf.packetPool.Get()
if len(messageType) > 0 {
slf.server.PushMessage(MessageTypeWritePacket, slf, data, messageType[0])
} else {
slf.server.PushMessage(MessageTypeWritePacket, slf, data, -1)
}
} else {
slf.server.PushMessage(MessageTypeWritePacket, slf, data)
cp.websocketMessageType = messageType[0]
}
cp.packet = data
slf.mutex.Lock()
slf.packets = append(slf.packets, cp)
slf.mutex.Unlock()
}
// Close 关闭连接
func (slf *Conn) Close() {
slf.mutex.Lock()
defer slf.mutex.Unlock()
if slf.ws != nil {
_ = slf.ws.Close()
} else if slf.gn != nil {
@ -107,6 +118,9 @@ func (slf *Conn) Close() {
_ = slf.kcp.Close()
}
slf.write = nil
slf.packetPool.Close()
slf.packetPool = nil
slf.packets = nil
}
// SetData 设置连接数据
@ -132,3 +146,61 @@ func (slf *Conn) ReleaseData() *Conn {
func (slf *Conn) IsWebsocket() bool {
return slf.server.network == NetworkWebsocket
}
// writeLoop 写循环
func (slf *Conn) writeLoop() {
slf.packetPool = synchronization.NewPool[*connPacket](64,
func() *connPacket {
return &connPacket{}
}, func(data *connPacket) {
data.packet = nil
data.websocketMessageType = -1
},
)
defer func() {
if err := recover(); err != nil {
slf.Close()
}
}()
for {
slf.mutex.Lock()
if slf.packetPool == nil {
return
}
if len(slf.packets) == 0 {
slf.mutex.Unlock()
time.Sleep(50 * time.Millisecond)
continue
}
packets := slf.packets[0:]
slf.packets = slf.packets[:0]
slf.mutex.Unlock()
for _, data := range packets {
data := data
if len(data.packet) == 0 {
for _, packet := range packets {
slf.packetPool.Release(packet)
}
slf.Close()
return
}
var err error
if slf.IsWebsocket() {
if data.websocketMessageType == -1 {
data.websocketMessageType = slf.server.websocketWriteMessageType
}
err = slf.ws.WriteMessage(data.websocketMessageType, data.packet)
} else {
if slf.gn != nil {
err = slf.gn.AsyncWrite(data.packet)
} else if slf.kcp != nil {
_, err = slf.kcp.Write(data.packet)
}
}
slf.packetPool.Release(data)
if err != nil {
panic(err)
}
}
}
}

6
server/conn_packet.go Normal file
View File

@ -0,0 +1,6 @@
package server
type connPacket struct {
websocketMessageType int
packet []byte
}

View File

@ -17,5 +17,4 @@ var (
ErrPleaseUseWebsocketHandle = errors.New("in Websocket mode, please use the RegConnectionReceiveWebsocketPacketEvent function to register")
ErrPleaseUseOrdinaryPacketHandle = errors.New("non Websocket mode, please use the RegConnectionReceivePacketEvent function to register")
ErrOnlySupportSocket = errors.New("only supports Socket programming")
ErrWebsocketMessageTypeWritePacketAttrs = errors.New("MessageTypeWritePacket must contain *Conn and []byte or *Conn and []byte and MessageType(int)")
)

View File

@ -6,11 +6,6 @@ const (
// - []byte
MessageTypePacket MessageType = iota
// MessageTypeWritePacket 数据包消息类型:该类型的消息将对客户端进行写入
// - *server.Conn
// - []byte
MessageTypeWritePacket
// MessageTypeError 错误消息类型:根据不同的错误状态,将交由 Server 进行统一处理
// - error
// - server.MessageErrorAction
@ -84,38 +79,6 @@ func (slf MessageType) deconstructPacket(attrs ...any) (conn *Conn, packet []byt
return
}
func (slf MessageType) deconstructWebSocketWritePacket(attrs ...any) (conn *Conn, packet []byte, messageType int) {
messageType = -1
if len(attrs) != 3 {
panic(ErrWebsocketMessageTypeWritePacketAttrs)
}
var ok bool
if conn, ok = attrs[0].(*Conn); !ok {
panic(ErrWebsocketMessageTypeWritePacketAttrs)
}
if packet, ok = attrs[1].([]byte); !ok {
panic(ErrWebsocketMessageTypeWritePacketAttrs)
}
if messageType, ok = attrs[2].(int); !ok {
panic(ErrWebsocketMessageTypeWritePacketAttrs)
}
return
}
func (slf MessageType) deconstructWritePacket(attrs ...any) (conn *Conn, packet []byte) {
if len(attrs) != 2 {
panic(ErrMessageTypePacketAttrs)
}
var ok bool
if conn, ok = attrs[0].(*Conn); !ok {
panic(ErrMessageTypePacketAttrs)
}
if packet, ok = attrs[1].([]byte); !ok {
panic(ErrMessageTypePacketAttrs)
}
return
}
func (slf MessageType) deconstructError(attrs ...any) (err error, action MessageErrorAction) {
if len(attrs) != 2 {
panic(ErrMessageTypeErrorAttrs)

View File

@ -410,27 +410,6 @@ func (slf *Server) dispatchMessage(msg *message) {
conn, packet := msg.t.deconstructPacket(msg.attrs...)
slf.OnConnectionReceivePacketEvent(conn, packet)
}
case MessageTypeWritePacket:
if slf.network == NetworkWebsocket {
conn, packet, messageType := msg.t.deconstructWebSocketWritePacket(msg.attrs...)
if messageType == -1 {
messageType = slf.websocketWriteMessageType
}
if err := conn.ws.WriteMessage(messageType, packet); err != nil {
log.Debug("Server", zap.String("ConnID", conn.GetID()), zap.Error(err))
}
} else {
var err error
conn, packet := msg.t.deconstructPacket(msg.attrs...)
if conn.gn != nil {
err = conn.gn.AsyncWrite(packet)
} else if conn.kcp != nil {
_, err = conn.kcp.Write(packet)
}
if err != nil {
log.Debug("Server", zap.String("ConnID", conn.GetID()), zap.Error(err))
}
}
case MessageTypeError:
err, action := msg.t.deconstructError(msg.attrs...)
switch action {