Merge branch 'develop'

This commit is contained in:
kercylan98 2023-09-05 11:22:23 +08:00
commit 067a4fb514
8 changed files with 281 additions and 145 deletions

1
go.mod
View File

@ -21,6 +21,7 @@ require (
github.com/xtaci/kcp-go/v5 v5.6.3 github.com/xtaci/kcp-go/v5 v5.6.3
go.uber.org/atomic v1.10.0 go.uber.org/atomic v1.10.0
go.uber.org/zap v1.25.0 go.uber.org/zap v1.25.0
golang.org/x/time v0.3.0
google.golang.org/grpc v1.57.0 google.golang.org/grpc v1.57.0
) )

View File

@ -8,7 +8,6 @@ import (
// NewClient 创建客户端 // NewClient 创建客户端
func NewClient(core Core) *Client { func NewClient(core Core) *Client {
client := &Client{ client := &Client{
cond: sync.NewCond(&sync.Mutex{}),
events: new(events), events: new(events),
core: core, core: core,
} }
@ -24,11 +23,11 @@ func CloneClient(client *Client) *Client {
type Client struct { type Client struct {
*events *events
core Core core Core
cond *sync.Cond mutex sync.Mutex
packetPool *concurrent.Pool[*Packet] packetPool *concurrent.Pool[*Packet]
packets []*Packet packets chan *Packet
accumulate []*Packet accumulate chan *Packet
accumulation int // 积压消息数 accumulation int // 积压消息数
} }
@ -44,14 +43,12 @@ func (slf *Client) Run() error {
}() }()
err := <-runState err := <-runState
if err != nil { if err != nil {
slf.cond.L.Lock() slf.mutex.Lock()
if slf.packetPool != nil { if slf.packetPool != nil {
slf.packetPool.Close() slf.packetPool.Close()
slf.packetPool = nil slf.packetPool = nil
} }
slf.accumulate = append(slf.accumulate, slf.packets...) slf.mutex.Unlock()
slf.packets = nil
slf.cond.L.Unlock()
return err return err
} }
var wait = new(sync.WaitGroup) var wait = new(sync.WaitGroup)
@ -69,12 +66,29 @@ func (slf *Client) IsConnected() bool {
// Close 关闭 // Close 关闭
func (slf *Client) Close(err ...error) { func (slf *Client) Close(err ...error) {
slf.mutex.Lock()
var unlock bool
defer func() {
if !unlock {
slf.mutex.Unlock()
}
}()
slf.core.Close() slf.core.Close()
if slf.packetPool != nil { if slf.packetPool != nil {
slf.packetPool.Close() slf.packetPool.Close()
slf.packetPool = nil slf.packetPool = nil
} }
if slf.packets != nil {
close(slf.packets)
slf.packets = nil
}
if slf.accumulate != nil {
close(slf.accumulate)
slf.accumulate = nil
}
slf.packets = nil slf.packets = nil
unlock = true
slf.mutex.Unlock()
if len(err) > 0 { if len(err) > 0 {
slf.OnConnectionClosedEvent(slf, err[0]) slf.OnConnectionClosedEvent(slf, err[0])
} else { } else {
@ -96,7 +110,8 @@ func (slf *Client) Write(packet []byte, callback ...func(err error)) {
// write 向连接中写入数据 // write 向连接中写入数据
// - messageType: websocket模式中指定消息类型 // - messageType: websocket模式中指定消息类型
func (slf *Client) write(wst int, packet []byte, callback ...func(err error)) { func (slf *Client) write(wst int, packet []byte, callback ...func(err error)) {
if slf.packetPool == nil { slf.mutex.Lock()
if slf.packetPool == nil || slf.packets == nil {
var p = &Packet{ var p = &Packet{
wst: wst, wst: wst,
data: packet, data: packet,
@ -104,27 +119,26 @@ func (slf *Client) write(wst int, packet []byte, callback ...func(err error)) {
if len(callback) > 0 { if len(callback) > 0 {
p.callback = callback[0] p.callback = callback[0]
} }
slf.cond.L.Lock() if slf.accumulate == nil {
slf.accumulate = append(slf.accumulate, p) slf.accumulate = make(chan *Packet, 1024*10)
}
slf.accumulate <- p
} else {
cp := slf.packetPool.Get()
cp.wst = wst
cp.data = packet
if len(callback) > 0 {
cp.callback = callback[0]
}
slf.packets <- cp
slf.accumulation = len(slf.accumulate) + len(slf.packets) slf.accumulation = len(slf.accumulate) + len(slf.packets)
slf.cond.L.Unlock()
return
} }
cp := slf.packetPool.Get() slf.mutex.Unlock()
cp.wst = wst
cp.data = packet
if len(callback) > 0 {
cp.callback = callback[0]
}
slf.cond.L.Lock()
slf.packets = append(slf.packets, cp)
slf.accumulation = len(slf.accumulate) + len(slf.packets)
slf.cond.Signal()
slf.cond.L.Unlock()
} }
// writeLoop 写循环 // writeLoop 写循环
func (slf *Client) writeLoop(wait *sync.WaitGroup) { func (slf *Client) writeLoop(wait *sync.WaitGroup) {
slf.packets = make(chan *Packet, 1024*10)
slf.packetPool = concurrent.NewPool[*Packet](10*1024, slf.packetPool = concurrent.NewPool[*Packet](10*1024,
func() *Packet { func() *Packet {
return &Packet{} return &Packet{}
@ -134,10 +148,11 @@ func (slf *Client) writeLoop(wait *sync.WaitGroup) {
data.callback = nil data.callback = nil
}, },
) )
slf.cond.L.Lock() go func() {
slf.packets = append(slf.packets, slf.accumulate...) for packet := range slf.accumulate {
slf.accumulate = nil slf.packets <- packet
slf.cond.L.Unlock() }
}()
defer func() { defer func() {
if err := recover(); err != nil { if err := recover(); err != nil {
slf.Close(err.(error)) slf.Close(err.(error))
@ -145,31 +160,19 @@ func (slf *Client) writeLoop(wait *sync.WaitGroup) {
}() }()
wait.Done() wait.Done()
for { for packet := range slf.packets {
slf.cond.L.Lock() data := packet
if slf.packetPool == nil { var err = slf.core.Write(data)
slf.cond.L.Unlock() callback := data.callback
return slf.packetPool.Release(data)
if callback != nil {
callback(err)
} }
if len(slf.packets) == 0 { if err != nil {
slf.cond.Wait() panic(err)
}
packets := slf.packets[0:]
slf.packets = slf.packets[0:0]
slf.cond.L.Unlock()
for i := 0; i < len(packets); i++ {
data := packets[i]
var err = slf.core.Write(data)
callback := data.callback
slf.packetPool.Release(data)
if callback != nil {
callback(err)
}
if err != nil {
panic(err)
}
} }
} }
} }
func (slf *Client) onReceive(wst int, packet []byte) { func (slf *Client) onReceive(wst int, packet []byte) {

View File

@ -4,11 +4,9 @@ import (
"context" "context"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/kercylan98/minotaur/utils/concurrent" "github.com/kercylan98/minotaur/utils/concurrent"
"github.com/kercylan98/minotaur/utils/log"
"github.com/panjf2000/gnet" "github.com/panjf2000/gnet"
"github.com/xtaci/kcp-go/v5" "github.com/xtaci/kcp-go/v5"
"net" "net"
"runtime/debug"
"strings" "strings"
"sync" "sync"
) )
@ -18,7 +16,8 @@ func newKcpConn(server *Server, session *kcp.UDPSession) *Conn {
c := &Conn{ c := &Conn{
ctx: server.ctx, ctx: server.ctx,
connection: &connection{ connection: &connection{
cond: sync.NewCond(&sync.Mutex{}), packets: make(chan *connPacket, 1024*10),
mutex: new(sync.Mutex),
server: server, server: server,
remoteAddr: session.RemoteAddr(), remoteAddr: session.RemoteAddr(),
ip: session.RemoteAddr().String(), ip: session.RemoteAddr().String(),
@ -41,7 +40,8 @@ func newGNetConn(server *Server, conn gnet.Conn) *Conn {
c := &Conn{ c := &Conn{
ctx: server.ctx, ctx: server.ctx,
connection: &connection{ connection: &connection{
cond: sync.NewCond(&sync.Mutex{}), packets: make(chan *connPacket, 1024*10),
mutex: new(sync.Mutex),
server: server, server: server,
remoteAddr: conn.RemoteAddr(), remoteAddr: conn.RemoteAddr(),
ip: conn.RemoteAddr().String(), ip: conn.RemoteAddr().String(),
@ -64,7 +64,8 @@ func newWebsocketConn(server *Server, ws *websocket.Conn, ip string) *Conn {
c := &Conn{ c := &Conn{
ctx: server.ctx, ctx: server.ctx,
connection: &connection{ connection: &connection{
cond: sync.NewCond(&sync.Mutex{}), packets: make(chan *connPacket, 1024*10),
mutex: new(sync.Mutex),
server: server, server: server,
remoteAddr: ws.RemoteAddr(), remoteAddr: ws.RemoteAddr(),
ip: ip, ip: ip,
@ -84,9 +85,10 @@ func newGatewayConn(conn *Conn, connId string) *Conn {
c := &Conn{ c := &Conn{
//ctx: server.ctx, //ctx: server.ctx,
connection: &connection{ connection: &connection{
cond: sync.NewCond(&sync.Mutex{}), packets: make(chan *connPacket, 1024*10),
server: conn.server, mutex: new(sync.Mutex),
data: map[any]any{}, server: conn.server,
data: map[any]any{},
}, },
} }
c.gw = func(packet []byte) { c.gw = func(packet []byte) {
@ -100,7 +102,8 @@ func NewEmptyConn(server *Server) *Conn {
c := &Conn{ c := &Conn{
ctx: server.ctx, ctx: server.ctx,
connection: &connection{ connection: &connection{
cond: sync.NewCond(&sync.Mutex{}), packets: make(chan *connPacket, 1024*10),
mutex: new(sync.Mutex),
server: server, server: server,
remoteAddr: &net.TCPAddr{}, remoteAddr: &net.TCPAddr{},
ip: "0.0.0.0:0", ip: "0.0.0.0:0",
@ -123,6 +126,7 @@ type Conn struct {
// connection 长久保持的连接 // connection 长久保持的连接
type connection struct { type connection struct {
server *Server server *Server
mutex *sync.Mutex
remoteAddr net.Addr remoteAddr net.Addr
ip string ip string
ws *websocket.Conn ws *websocket.Conn
@ -130,9 +134,8 @@ type connection struct {
kcp *kcp.UDPSession kcp *kcp.UDPSession
gw func(packet []byte) gw func(packet []byte)
data map[any]any data map[any]any
cond *sync.Cond
packetPool *concurrent.Pool[*connPacket] packetPool *concurrent.Pool[*connPacket]
packets []*connPacket packets chan *connPacket
} }
// IsEmpty 是否是空连接 // IsEmpty 是否是空连接
@ -144,12 +147,6 @@ func (slf *Conn) IsEmpty() bool {
// - 重用连接时,会将当前连接的数据复制到新连接中 // - 重用连接时,会将当前连接的数据复制到新连接中
// - 通常在于连接断开后,重新连接时使用 // - 通常在于连接断开后,重新连接时使用
func (slf *Conn) Reuse(conn *Conn) { func (slf *Conn) Reuse(conn *Conn) {
slf.cond.L.Lock()
conn.cond.L.Lock()
defer func() {
slf.cond.L.Unlock()
conn.cond.L.Unlock()
}()
slf.Close() slf.Close()
slf.remoteAddr = conn.remoteAddr slf.remoteAddr = conn.remoteAddr
slf.ip = conn.ip slf.ip = conn.ip
@ -190,7 +187,10 @@ func (slf *Conn) Close() {
slf.packetPool.Close() slf.packetPool.Close()
} }
slf.packetPool = nil slf.packetPool = nil
slf.packets = nil if slf.packets != nil {
close(slf.packets)
slf.packets = nil
}
} }
// SetData 设置连接数据,该数据将在连接关闭前始终存在 // SetData 设置连接数据,该数据将在连接关闭前始终存在
@ -243,12 +243,14 @@ func (slf *Conn) SetWST(wst int) *Conn {
// Write 向连接中写入数据 // Write 向连接中写入数据
// - messageType: websocket模式中指定消息类型 // - messageType: websocket模式中指定消息类型
func (slf *Conn) Write(packet []byte, callback ...func(err error)) { func (slf *Conn) Write(packet []byte, callback ...func(err error)) {
slf.mutex.Lock()
defer slf.mutex.Unlock()
if slf.gw != nil { if slf.gw != nil {
slf.gw(packet) slf.gw(packet)
return return
} }
packet = slf.server.OnConnectionWritePacketBeforeEvent(slf, packet) packet = slf.server.OnConnectionWritePacketBeforeEvent(slf, packet)
if slf.packetPool == nil { if slf.packetPool == nil || slf.packets == nil {
return return
} }
cp := slf.packetPool.Get() cp := slf.packetPool.Get()
@ -257,10 +259,7 @@ func (slf *Conn) Write(packet []byte, callback ...func(err error)) {
if len(callback) > 0 { if len(callback) > 0 {
cp.callback = callback[0] cp.callback = callback[0]
} }
slf.cond.L.Lock() slf.packets <- cp
slf.packets = append(slf.packets, cp)
slf.cond.Signal()
slf.cond.L.Unlock()
} }
// writeLoop 写循环 // writeLoop 写循环
@ -277,49 +276,38 @@ func (slf *Conn) writeLoop(wait *sync.WaitGroup) {
defer func() { defer func() {
if err := recover(); err != nil { if err := recover(); err != nil {
slf.Close() slf.Close()
log.Error("WriteLoop", log.Any("Error", err)) // TODO: 以下代码是否需要?
debug.PrintStack() // log.Error("WriteLoop", log.Any("Error", err))
// debug.PrintStack()
} }
}() }()
wait.Done() wait.Done()
for { for packet := range slf.packets {
slf.cond.L.Lock()
if slf.packetPool == nil {
slf.cond.L.Unlock()
return
}
if len(slf.packets) == 0 {
slf.cond.Wait()
}
packets := slf.packets[0:]
slf.packets = slf.packets[0:0]
slf.cond.L.Unlock()
for i := 0; i < len(packets); i++ {
data := packets[i]
var err error
if slf.IsWebsocket() {
err = slf.ws.WriteMessage(data.wst, data.packet)
} else {
if slf.gn != nil {
switch slf.server.network {
case NetworkUdp, NetworkUdp4, NetworkUdp6:
err = slf.gn.SendTo(data.packet)
default:
err = slf.gn.AsyncWrite(data.packet)
}
} else if slf.kcp != nil { data := packet
_, err = slf.kcp.Write(data.packet) var err error
if slf.IsWebsocket() {
err = slf.ws.WriteMessage(data.wst, data.packet)
} else {
if slf.gn != nil {
switch slf.server.network {
case NetworkUdp, NetworkUdp4, NetworkUdp6:
err = slf.gn.SendTo(data.packet)
default:
err = slf.gn.AsyncWrite(data.packet)
} }
} else if slf.kcp != nil {
_, err = slf.kcp.Write(data.packet)
} }
callback := data.callback }
slf.packetPool.Release(data) callback := data.callback
if callback != nil { slf.packetPool.Release(data)
callback(err)
} if callback != nil {
if err != nil { callback(err)
panic(err) }
} if err != nil {
panic(err)
} }
} }
} }

View File

@ -22,7 +22,7 @@ const (
const ( const (
DefaultMessageBufferSize = 1024 DefaultMessageBufferSize = 1024
DefaultMessageChannelSize = 1024 * 64 DefaultMessageChannelSize = 1024 * 1024
DefaultAsyncPoolSize = 256 DefaultAsyncPoolSize = 256
DefaultWebsocketReadDeadline = 30 * time.Second DefaultWebsocketReadDeadline = 30 * time.Second
) )

View File

@ -26,6 +26,7 @@ type ConnectionWritePacketBeforeEventHandle func(srv *Server, conn *Conn, packet
type ShuntChannelCreatedEventHandle func(srv *Server, guid int64) type ShuntChannelCreatedEventHandle func(srv *Server, guid int64)
type ShuntChannelClosedEventHandle func(srv *Server, guid int64) type ShuntChannelClosedEventHandle func(srv *Server, guid int64)
type ConnectionPacketPreprocessEventHandle func(srv *Server, conn *Conn, packet []byte, abort func(), usePacket func(newPacket []byte)) type ConnectionPacketPreprocessEventHandle func(srv *Server, conn *Conn, packet []byte, abort func(), usePacket func(newPacket []byte))
type MessageExecBeforeEventHandle func(srv *Server, message *Message) bool
func newEvent(srv *Server) *event { func newEvent(srv *Server) *event {
return &event{ return &event{
@ -44,6 +45,7 @@ func newEvent(srv *Server) *event {
shuntChannelCreatedEventHandles: slice.NewPriority[ShuntChannelCreatedEventHandle](), shuntChannelCreatedEventHandles: slice.NewPriority[ShuntChannelCreatedEventHandle](),
shuntChannelClosedEventHandles: slice.NewPriority[ShuntChannelClosedEventHandle](), shuntChannelClosedEventHandles: slice.NewPriority[ShuntChannelClosedEventHandle](),
connectionPacketPreprocessEventHandles: slice.NewPriority[ConnectionPacketPreprocessEventHandle](), connectionPacketPreprocessEventHandles: slice.NewPriority[ConnectionPacketPreprocessEventHandle](),
messageExecBeforeEventHandles: slice.NewPriority[MessageExecBeforeEventHandle](),
} }
} }
@ -63,6 +65,7 @@ type event struct {
shuntChannelCreatedEventHandles *slice.Priority[ShuntChannelCreatedEventHandle] shuntChannelCreatedEventHandles *slice.Priority[ShuntChannelCreatedEventHandle]
shuntChannelClosedEventHandles *slice.Priority[ShuntChannelClosedEventHandle] shuntChannelClosedEventHandles *slice.Priority[ShuntChannelClosedEventHandle]
connectionPacketPreprocessEventHandles *slice.Priority[ConnectionPacketPreprocessEventHandle] connectionPacketPreprocessEventHandles *slice.Priority[ConnectionPacketPreprocessEventHandle]
messageExecBeforeEventHandles *slice.Priority[MessageExecBeforeEventHandle]
consoleCommandEventHandles map[string]*slice.Priority[ConsoleCommandEventHandle] consoleCommandEventHandles map[string]*slice.Priority[ConsoleCommandEventHandle]
consoleCommandEventHandleInitOnce sync.Once consoleCommandEventHandleInitOnce sync.Once
@ -353,6 +356,34 @@ func (slf *event) OnConnectionPacketPreprocessEvent(conn *Conn, packet []byte, u
return abort return abort
} }
// RegMessageExecBeforeEvent 在处理消息前将立刻执行被注册的事件处理函数
// - 当返回 true 时,将继续执行后续的消息处理函数,否则将不会执行后续的消息处理函数,并且该消息将被丢弃
//
// 适用于限流等场景
func (slf *event) RegMessageExecBeforeEvent(handle MessageExecBeforeEventHandle, priority ...int) {
slf.messageExecBeforeEventHandles.Append(handle, slice.GetValue(priority, 0))
log.Info("Server", log.String("RegEvent", runtimes.CurrentRunningFuncName()), log.String("handle", reflect.TypeOf(handle).String()))
}
// OnMessageExecBeforeEvent 执行消息处理前的事件处理函数
func (slf *event) OnMessageExecBeforeEvent(message *Message) bool {
if slf.messageExecBeforeEventHandles.Len() == 0 {
return true
}
var result = true
defer func() {
if err := recover(); err != nil {
log.Error("Server", log.String("OnMessageExecBeforeEvent", fmt.Sprintf("%v", err)))
debug.PrintStack()
}
}()
slf.messageExecBeforeEventHandles.RangeValue(func(index int, value MessageExecBeforeEventHandle) bool {
result = value(slf.Server, message)
return result
})
return result
}
func (slf *event) check() { func (slf *event) check() {
switch slf.network { switch slf.network {
case NetworkHttp, NetworkGRPC, NetworkNone: case NetworkHttp, NetworkGRPC, NetworkNone:

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/kercylan98/minotaur/utils/hash"
"reflect" "reflect"
) )
@ -64,6 +65,11 @@ type (
MessageErrorAction byte MessageErrorAction byte
) )
// HasMessageType 检查是否存在指定的消息类型
func HasMessageType(mt MessageType) bool {
return hash.Exist(messageNames, mt)
}
func (slf MessageErrorAction) String() string { func (slf MessageErrorAction) String() string {
return messageErrorActionNames[slf] return messageErrorActionNames[slf]
} }
@ -74,6 +80,11 @@ type Message struct {
attrs []any // 消息属性 attrs []any // 消息属性
} }
// MessageType 返回消息类型
func (slf *Message) MessageType() MessageType {
return slf.t
}
// String 返回消息的字符串表示 // String 返回消息的字符串表示
func (slf *Message) String() string { func (slf *Message) String() string {
var attrs = make([]any, 0, len(slf.attrs)) var attrs = make([]any, 0, len(slf.attrs))
@ -86,7 +97,9 @@ func (slf *Message) String() string {
var s string var s string
switch slf.t { switch slf.t {
case MessageTypePacket: case MessageTypePacket:
s = messagePacketVisualization(attrs[1].([]byte)) if len(attrs) > 1 {
s = messagePacketVisualization(attrs[1].([]byte))
}
default: default:
if len(slf.attrs) == 0 { if len(slf.attrs) == 0 {
s = "NoneAttr" s = "NoneAttr"
@ -104,6 +117,13 @@ func (slf MessageType) String() string {
return messageNames[slf] return messageNames[slf]
} }
// GetPacketMessageAttrs 获取消息中的数据包属性
func (slf *Message) GetPacketMessageAttrs() (conn *Conn, packet []byte) {
conn = slf.attrs[0].(*Conn)
packet = slf.attrs[1].([]byte)
return
}
// PushPacketMessage 向特定服务器中推送 MessageTypePacket 消息 // PushPacketMessage 向特定服务器中推送 MessageTypePacket 消息
func PushPacketMessage(srv *Server, conn *Conn, wst int, packet []byte, mark ...any) { func PushPacketMessage(srv *Server, conn *Conn, wst int, packet []byte, mark ...any) {
msg := srv.messagePool.Get() msg := srv.messagePool.Get()
@ -112,6 +132,13 @@ func PushPacketMessage(srv *Server, conn *Conn, wst int, packet []byte, mark ...
srv.pushMessage(msg) srv.pushMessage(msg)
} }
// GetErrorMessageAttrs 获取消息中的错误属性
func (slf *Message) GetErrorMessageAttrs() (err error, action MessageErrorAction) {
err = slf.attrs[0].(error)
action = slf.attrs[1].(MessageErrorAction)
return
}
// PushErrorMessage 向特定服务器中推送 MessageTypeError 消息 // PushErrorMessage 向特定服务器中推送 MessageTypeError 消息
func PushErrorMessage(srv *Server, err error, action MessageErrorAction, mark ...any) { func PushErrorMessage(srv *Server, err error, action MessageErrorAction, mark ...any) {
msg := srv.messagePool.Get() msg := srv.messagePool.Get()
@ -120,6 +147,13 @@ func PushErrorMessage(srv *Server, err error, action MessageErrorAction, mark ..
srv.pushMessage(msg) srv.pushMessage(msg)
} }
// GetCrossMessageAttrs 获取消息中的跨服属性
func (slf *Message) GetCrossMessageAttrs() (serverId int64, packet []byte) {
serverId = slf.attrs[0].(int64)
packet = slf.attrs[1].([]byte)
return
}
// PushCrossMessage 向特定服务器中推送 MessageTypeCross 消息 // PushCrossMessage 向特定服务器中推送 MessageTypeCross 消息
func PushCrossMessage(srv *Server, crossName string, serverId int64, packet []byte, mark ...any) { func PushCrossMessage(srv *Server, crossName string, serverId int64, packet []byte, mark ...any) {
if serverId == srv.id { if serverId == srv.id {
@ -139,6 +173,12 @@ func PushCrossMessage(srv *Server, crossName string, serverId int64, packet []by
} }
} }
// GetTickerMessageAttrs 获取消息中的定时器属性
func (slf *Message) GetTickerMessageAttrs() (caller func()) {
caller = slf.attrs[0].(func())
return
}
// PushTickerMessage 向特定服务器中推送 MessageTypeTicker 消息 // PushTickerMessage 向特定服务器中推送 MessageTypeTicker 消息
func PushTickerMessage(srv *Server, caller func(), mark ...any) { func PushTickerMessage(srv *Server, caller func(), mark ...any) {
msg := srv.messagePool.Get() msg := srv.messagePool.Get()
@ -147,6 +187,13 @@ func PushTickerMessage(srv *Server, caller func(), mark ...any) {
srv.pushMessage(msg) srv.pushMessage(msg)
} }
// GetAsyncMessageAttrs 获取消息中的异步消息属性
func (slf *Message) GetAsyncMessageAttrs() (caller func() error, callback func(err error), hasCallback bool) {
caller = slf.attrs[0].(func() error)
callback, hasCallback = slf.attrs[1].(func(err error))
return
}
// PushAsyncMessage 向特定服务器中推送 MessageTypeAsync 消息 // PushAsyncMessage 向特定服务器中推送 MessageTypeAsync 消息
// - 异步消息将在服务器的异步消息队列中进行处理,处理完成 caller 的阻塞操作后,将会通过系统消息执行 callback 函数 // - 异步消息将在服务器的异步消息队列中进行处理,处理完成 caller 的阻塞操作后,将会通过系统消息执行 callback 函数
// - callback 函数将在异步消息处理完成后进行调用,无论过程是否产生 err都将被执行允许为 nil // - callback 函数将在异步消息处理完成后进行调用,无论过程是否产生 err都将被执行允许为 nil
@ -160,6 +207,12 @@ func PushAsyncMessage(srv *Server, caller func() error, callback func(err error)
srv.pushMessage(msg) srv.pushMessage(msg)
} }
// GetSystemMessageAttrs 获取消息中的系统消息属性
func (slf *Message) GetSystemMessageAttrs() (handle func()) {
handle = slf.attrs[0].(func())
return
}
// PushSystemMessage 向特定服务器中推送 MessageTypeSystem 消息 // PushSystemMessage 向特定服务器中推送 MessageTypeSystem 消息
func PushSystemMessage(srv *Server, handle func(), mark ...any) { func PushSystemMessage(srv *Server, handle func(), mark ...any) {
msg := srv.messagePool.Get() msg := srv.messagePool.Get()

View File

@ -102,18 +102,17 @@ type Server struct {
} }
// Run 使用特定地址运行服务器 // Run 使用特定地址运行服务器
// // - server.NetworkTcp (addr:":8888")
// server.NetworkTcp (addr:":8888") // - server.NetworkTcp4 (addr:":8888")
// server.NetworkTcp4 (addr:":8888") // - server.NetworkTcp6 (addr:":8888")
// server.NetworkTcp6 (addr:":8888") // - server.NetworkUdp (addr:":8888")
// server.NetworkUdp (addr:":8888") // - server.NetworkUdp4 (addr:":8888")
// server.NetworkUdp4 (addr:":8888") // - server.NetworkUdp6 (addr:":8888")
// server.NetworkUdp6 (addr:":8888") // - server.NetworkUnix (addr:"socketPath")
// server.NetworkUnix (addr:"socketPath") // - server.NetworkHttp (addr:":8888")
// server.NetworkHttp (addr:":8888") // - server.NetworkWebsocket (addr:":8888/ws")
// server.NetworkWebsocket (addr:":8888/ws") // - server.NetworkKcp (addr:":8888")
// server.NetworkKcp (addr:":8888") // - server.NetworkNone (addr:"")
// server.NetworkNone (addr:"")
func (slf *Server) Run(addr string) error { func (slf *Server) Run(addr string) error {
if slf.network == NetworkNone { if slf.network == NetworkNone {
addr = "-" addr = "-"
@ -142,12 +141,13 @@ func (slf *Server) Run(addr string) error {
if callback != nil { if callback != nil {
go callback() go callback()
} }
go func() { go func(messageChannel <-chan *Message) {
messageInitFinish <- struct{}{} messageInitFinish <- struct{}{}
for message := range slf.messageChannel { for message := range messageChannel {
slf.dispatchMessage(message) msg := message
slf.dispatchMessage(msg)
} }
}() }(slf.messageChannel)
} }
switch slf.network { switch slf.network {
@ -364,6 +364,16 @@ func (slf *Server) RunNone() error {
return slf.Run(str.None) return slf.Run(str.None)
} }
// Context 获取服务器上下文
func (slf *Server) Context() context.Context {
return slf.ctx
}
// TimeoutContext 获取服务器超时上下文context.WithTimeout 的简写
func (slf *Server) TimeoutContext(timeout time.Duration) (context.Context, context.CancelFunc) {
return context.WithTimeout(slf.ctx, timeout)
}
// GetOnlineCount 获取在线人数 // GetOnlineCount 获取在线人数
func (slf *Server) GetOnlineCount() int { func (slf *Server) GetOnlineCount() int {
return slf.online.Size() return slf.online.Size()
@ -541,13 +551,10 @@ func (slf *Server) ShuntChannelFreed(channelGuid int64) {
// pushMessage 向服务器中写入特定类型的消息,需严格遵守消息属性要求 // pushMessage 向服务器中写入特定类型的消息,需严格遵守消息属性要求
func (slf *Server) pushMessage(message *Message) { func (slf *Server) pushMessage(message *Message) {
if slf.messagePool.IsClose() { if slf.messagePool.IsClose() || slf.isShutdown.Load() || !slf.OnMessageExecBeforeEvent(message) {
slf.messagePool.Release(message) slf.messagePool.Release(message)
return return
} }
if slf.isShutdown.Load() {
return
}
if slf.shuntChannels != nil && message.t == MessageTypePacket { if slf.shuntChannels != nil && message.t == MessageTypePacket {
conn := message.attrs[0].(*Conn) conn := message.attrs[0].(*Conn)
channelGuid, allowToCreate := slf.shuntMatcher(conn) channelGuid, allowToCreate := slf.shuntMatcher(conn)
@ -568,6 +575,7 @@ func (slf *Server) pushMessage(message *Message) {
} }
} }
slf.messageChannel <- message slf.messageChannel <- message
} }
func (slf *Server) low(message *Message, present time.Time, expect time.Duration, messageReplace ...string) { func (slf *Server) low(message *Message, present time.Time, expect time.Duration, messageReplace ...string) {
@ -593,18 +601,18 @@ func (slf *Server) dispatchMessage(msg *Message) {
) )
if slf.deadlockDetect > 0 { if slf.deadlockDetect > 0 {
ctx, cancel = context.WithTimeout(context.Background(), slf.deadlockDetect) ctx, cancel = context.WithTimeout(context.Background(), slf.deadlockDetect)
go func() { go func(ctx context.Context, msg *Message) {
select { select {
case <-ctx.Done(): case <-ctx.Done():
if err := ctx.Err(); err == context.DeadlineExceeded { if err := ctx.Err(); err == context.DeadlineExceeded {
log.Warn("Server", log.String("MessageType", messageNames[msg.t]), log.Any("SuspectedDeadlock", msg.attrs)) log.Warn("Server", log.String("MessageType", messageNames[msg.t]), log.Any("SuspectedDeadlock", msg.attrs))
} }
} }
}() }(ctx, msg)
} }
present := time.Now() present := time.Now()
defer func() { defer func(msg *Message) {
if err := recover(); err != nil { if err := recover(); err != nil {
stack := string(debug.Stack()) stack := string(debug.Stack())
log.Error("Server", log.String("MessageType", messageNames[msg.t]), log.Any("MessageAttrs", msg.attrs), log.Any("error", err), log.String("stack", stack)) log.Error("Server", log.String("MessageType", messageNames[msg.t]), log.Any("MessageAttrs", msg.attrs), log.Any("error", err), log.String("stack", stack))
@ -626,17 +634,16 @@ func (slf *Server) dispatchMessage(msg *Message) {
slf.messagePool.Release(msg) slf.messagePool.Release(msg)
} }
}() }(msg)
var attrs = msg.attrs var attrs = msg.attrs
switch msg.t { switch msg.t {
case MessageTypePacket: case MessageTypePacket:
var conn = attrs[0].(*Conn) var conn, packet = msg.GetPacketMessageAttrs()
var packet = attrs[1].([]byte)
if !slf.OnConnectionPacketPreprocessEvent(conn, packet, func(newPacket []byte) { packet = newPacket }) { if !slf.OnConnectionPacketPreprocessEvent(conn, packet, func(newPacket []byte) { packet = newPacket }) {
slf.OnConnectionReceivePacketEvent(conn, packet) slf.OnConnectionReceivePacketEvent(conn, packet)
} }
case MessageTypeError: case MessageTypeError:
err, action := attrs[0].(error), attrs[1].(MessageErrorAction) var err, action = msg.GetErrorMessageAttrs()
switch action { switch action {
case MessageErrorActionNone: case MessageErrorActionNone:
log.Panic("Server", log.Err(err)) log.Panic("Server", log.Err(err))
@ -646,12 +653,11 @@ func (slf *Server) dispatchMessage(msg *Message) {
log.Warn("Server", log.String("not support message error action", action.String())) log.Warn("Server", log.String("not support message error action", action.String()))
} }
case MessageTypeCross: case MessageTypeCross:
slf.OnReceiveCrossPacketEvent(attrs[0].(int64), attrs[1].([]byte)) slf.OnReceiveCrossPacketEvent(msg.GetCrossMessageAttrs())
case MessageTypeTicker: case MessageTypeTicker:
attrs[0].(func())() msg.GetTickerMessageAttrs()()
case MessageTypeAsync: case MessageTypeAsync:
handle := attrs[0].(func() error) handle, callback, cb := msg.GetAsyncMessageAttrs()
callback, cb := attrs[1].(func(err error))
if err := slf.ants.Submit(func() { if err := slf.ants.Submit(func() {
defer func() { defer func() {
if err := recover(); err != nil { if err := recover(); err != nil {
@ -686,10 +692,10 @@ func (slf *Server) dispatchMessage(msg *Message) {
}); err != nil { }); err != nil {
panic(err) panic(err)
} }
case MessageTypeAsyncCallback: case MessageTypeAsyncCallback: // 特殊类型
attrs[0].(func())() attrs[0].(func())()
case MessageTypeSystem: case MessageTypeSystem:
attrs[0].(func())() msg.GetSystemMessageAttrs()()
default: default:
log.Warn("Server", log.String("not support message type", msg.t.String())) log.Warn("Server", log.String("not support message type", msg.t.String()))
} }

54
server/server_test.go Normal file
View File

@ -0,0 +1,54 @@
package server_test
import (
"fmt"
"github.com/kercylan98/minotaur/server"
"github.com/kercylan98/minotaur/server/client"
"github.com/kercylan98/minotaur/utils/times"
"golang.org/x/time/rate"
"sync/atomic"
"testing"
"time"
)
func TestNew(t *testing.T) {
limiter := rate.NewLimiter(rate.Every(time.Second), 100)
srv := server.New(server.NetworkWebsocket, server.WithMessageBufferSize(1024*1024), server.WithPProf())
srv.RegMessageExecBeforeEvent(func(srv *server.Server, message *server.Message) bool {
t, c := srv.TimeoutContext(time.Second * 5)
defer c()
if err := limiter.Wait(t); err != nil {
return false
}
return true
})
srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet []byte) {
conn.Write(packet)
})
if err := srv.Run(":9999"); err != nil {
panic(err)
}
}
func TestNewClient(t *testing.T) {
var total atomic.Int64
for i := 0; i < 1000; i++ {
cli := client.NewWebsocket("ws://127.0.0.1:9999")
cli.RegConnectionReceivePacketEvent(func(conn *client.Client, wst int, packet []byte) {
fmt.Println(string(packet))
})
cli.RegConnectionOpenedEvent(func(conn *client.Client) {
go func() {
for {
cli.WriteWS(2, []byte("hello"))
total.Add(1)
}
}()
})
if err := cli.Run(); err != nil {
panic(err)
}
}
time.Sleep(times.Week)
}