style: 优化 server 包代码可读性

This commit is contained in:
kercylan98 2023-12-25 17:40:55 +08:00
parent 7ee4b893cd
commit af0a5a1c25
9 changed files with 691 additions and 586 deletions

View File

@ -16,23 +16,27 @@ import (
)
type (
StartBeforeEventHandler func(srv *Server)
StartFinishEventHandler func(srv *Server)
StopEventHandler func(srv *Server)
ConnectionReceivePacketEventHandler func(srv *Server, conn *Conn, packet []byte)
MessageReadyEventHandler func(srv *Server)
StartBeforeEventHandler func(srv *Server)
StartFinishEventHandler func(srv *Server)
StopEventHandler func(srv *Server)
ConnectionOpenedEventHandler func(srv *Server, conn *Conn)
ConnectionClosedEventHandler func(srv *Server, conn *Conn, err any)
MessageErrorEventHandler func(srv *Server, message *Message, err error)
MessageLowExecEventHandler func(srv *Server, message *Message, cost time.Duration)
ConsoleCommandEventHandler func(srv *Server, command string, params ConsoleParams)
ConnectionOpenedAfterEventHandler func(srv *Server, conn *Conn)
ConnectionWritePacketBeforeEventHandler func(srv *Server, conn *Conn, packet []byte) []byte
ShuntChannelCreatedEventHandler func(srv *Server, guid int64)
ShuntChannelClosedEventHandler func(srv *Server, guid int64)
ConnectionPacketPreprocessEventHandler func(srv *Server, conn *Conn, packet []byte, abort func(), usePacket func(newPacket []byte))
MessageExecBeforeEventHandler func(srv *Server, message *Message) bool
MessageReadyEventHandler func(srv *Server)
OnDeadlockDetectEventHandler func(srv *Server, message *Message)
ConnectionReceivePacketEventHandler func(srv *Server, conn *Conn, packet []byte)
ConnectionWritePacketBeforeEventHandler func(srv *Server, conn *Conn, packet []byte) []byte
ConnectionClosedEventHandler func(srv *Server, conn *Conn, err any)
ShuntChannelCreatedEventHandler func(srv *Server, name string)
ShuntChannelClosedEventHandler func(srv *Server, name string)
MessageExecBeforeEventHandler func(srv *Server, message *Message) bool
MessageLowExecEventHandler func(srv *Server, message *Message, cost time.Duration)
MessageErrorEventHandler func(srv *Server, message *Message, err error)
ConsoleCommandEventHandler func(srv *Server, command string, params ConsoleParams)
OnDeadlockDetectEventHandler func(srv *Server, message *Message)
)
func newEvent(srv *Server) *event {
@ -207,7 +211,7 @@ func (slf *event) RegConnectionClosedEvent(handler ConnectionClosedEventHandler,
func (slf *event) OnConnectionClosedEvent(conn *Conn, err any) {
slf.PushShuntMessage(conn, func() {
slf.Server.online.Delete(conn.GetID())
slf.Server.online.Del(conn.GetID())
slf.connectionClosedEventHandlers.RangeValue(func(index int, value ConnectionClosedEventHandler) bool {
value(slf.Server, conn, err)
return true
@ -340,10 +344,10 @@ func (slf *event) RegShuntChannelCreatedEvent(handler ShuntChannelCreatedEventHa
log.Info("Server", log.String("RegEvent", runtimes.CurrentRunningFuncName()), log.String("handler", reflect.TypeOf(handler).String()))
}
func (slf *event) OnShuntChannelCreatedEvent(guid int64) {
func (slf *event) OnShuntChannelCreatedEvent(name string) {
slf.PushSystemMessage(func() {
slf.shuntChannelCreatedEventHandlers.RangeValue(func(index int, value ShuntChannelCreatedEventHandler) bool {
value(slf.Server, guid)
value(slf.Server, name)
return true
})
}, log.String("Event", "OnShuntChannelCreatedEvent"))
@ -355,10 +359,10 @@ func (slf *event) RegShuntChannelCloseEvent(handler ShuntChannelClosedEventHandl
log.Info("Server", log.String("RegEvent", runtimes.CurrentRunningFuncName()), log.String("handler", reflect.TypeOf(handler).String()))
}
func (slf *event) OnShuntChannelClosedEvent(guid int64) {
func (slf *event) OnShuntChannelClosedEvent(name string) {
slf.PushSystemMessage(func() {
slf.shuntChannelClosedEventHandlers.RangeValue(func(index int, value ShuntChannelClosedEventHandler) bool {
value(slf.Server, guid)
value(slf.Server, name)
return true
})
}, log.String("Event", "OnShuntChannelClosedEvent"))
@ -460,13 +464,3 @@ func (slf *event) OnDeadlockDetectEvent(message *Message) {
return true
})
}
func (slf *event) check() {
switch slf.network {
case NetworkHttp, NetworkGRPC, NetworkNone:
default:
if slf.connectionReceivePacketEventHandlers.Len() == 0 {
log.Warn("Server", log.String("ConnectionReceivePacketEvent", "invalid server, no packets processed"))
}
}
}

View File

@ -8,42 +8,47 @@ import (
type gNet struct {
*Server
state chan<- error
}
func (slf *gNet) OnInitComplete(server gnet.Server) (action gnet.Action) {
func (g *gNet) OnInitComplete(server gnet.Server) (action gnet.Action) {
if g.state != nil {
g.state <- nil
g.state = nil
}
return
}
func (slf *gNet) OnShutdown(server gnet.Server) {
func (g *gNet) OnShutdown(server gnet.Server) {
return
}
func (slf *gNet) OnOpened(c gnet.Conn) (out []byte, action gnet.Action) {
conn := newGNetConn(slf.Server, c)
func (g *gNet) OnOpened(c gnet.Conn) (out []byte, action gnet.Action) {
conn := newGNetConn(g.Server, c)
c.SetContext(conn)
slf.OnConnectionOpenedEvent(conn)
g.OnConnectionOpenedEvent(conn)
return
}
func (slf *gNet) OnClosed(c gnet.Conn, err error) (action gnet.Action) {
func (g *gNet) OnClosed(c gnet.Conn, err error) (action gnet.Action) {
conn := c.Context().(*Conn)
conn.Close(err)
return
}
func (slf *gNet) PreWrite(c gnet.Conn) {
func (g *gNet) PreWrite(c gnet.Conn) {
return
}
func (slf *gNet) AfterWrite(c gnet.Conn, b []byte) {
func (g *gNet) AfterWrite(c gnet.Conn, b []byte) {
return
}
func (slf *gNet) React(packet []byte, c gnet.Conn) (out []byte, action gnet.Action) {
slf.Server.PushPacketMessage(c.Context().(*Conn), 0, bytes.Clone(packet))
func (g *gNet) React(packet []byte, c gnet.Conn) (out []byte, action gnet.Action) {
g.Server.PushPacketMessage(c.Context().(*Conn), 0, bytes.Clone(packet))
return nil, gnet.None
}
func (slf *gNet) Tick() (delay time.Duration, action gnet.Action) {
func (g *gNet) Tick() (delay time.Duration, action gnet.Action) {
return
}

34
server/listener.go Normal file
View File

@ -0,0 +1,34 @@
package server
import (
"github.com/xtaci/kcp-go/v5"
"net"
"sync"
)
type listener struct {
srv *Server
once sync.Once
net.Listener
kcpListener *kcp.Listener
state chan<- error
}
func (l *listener) init() *listener {
l.srv.OnStartBeforeEvent()
return l
}
func (l *listener) Accept() (net.Conn, error) {
l.once.Do(func() {
l.state <- nil
})
return l.Listener.Accept()
}
func (l *listener) AcceptKCP() (*kcp.UDPSession, error) {
l.once.Do(func() {
l.state <- nil
})
return l.kcpListener.AcceptKCP()
}

View File

@ -10,9 +10,6 @@ const (
// MessageTypePacket 数据包消息类型:该类型的数据将被发送到 ConnectionReceivePacketEvent 进行处理
MessageTypePacket MessageType = iota + 1
// MessageTypeError 错误消息类型:根据不同的错误状态,将交由 Server 进行统一处理
MessageTypeError
// MessageTypeTicker 定时器消息类型
MessageTypeTicker
@ -52,7 +49,6 @@ const (
var messageNames = map[MessageType]string{
MessageTypePacket: "MessageTypePacket",
MessageTypeError: "MessageTypeError",
MessageTypeTicker: "MessageTypeTicker",
MessageTypeShuntTicker: "MessageTypeShuntTicker",
MessageTypeAsync: "MessageTypeAsync",
@ -67,22 +63,9 @@ var messageNames = map[MessageType]string{
MessageTypeShunt: "MessageTypeShunt",
}
const (
MessageErrorActionNone MessageErrorAction = iota + 1 // 错误消息类型操作:将不会被进行任何特殊处理,仅进行日志输出
MessageErrorActionShutdown // 错误消息类型操作:当接收到该类型的操作时,服务器将执行 Server.shutdown 函数
)
var messageErrorActionNames = map[MessageErrorAction]string{
MessageErrorActionNone: "None",
MessageErrorActionShutdown: "Shutdown",
}
type (
// MessageType 消息类型
MessageType byte
// MessageErrorAction 错误消息类型操作
MessageErrorAction byte
)
// HasMessageType 检查是否存在指定的消息类型
@ -90,10 +73,6 @@ func HasMessageType(mt MessageType) bool {
return hash.Exist(messageNames, mt)
}
func (slf MessageErrorAction) String() string {
return messageErrorActionNames[slf]
}
// Message 服务器消息
type Message struct {
conn *Conn
@ -104,7 +83,6 @@ type Message struct {
err error
name string
t MessageType
errAction MessageErrorAction
marks []log.Field
}
@ -118,7 +96,6 @@ func (slf *Message) reset() {
slf.err = nil
slf.name = ""
slf.t = 0
slf.errAction = 0
slf.marks = nil
}
@ -219,12 +196,6 @@ func (slf *Message) castToSystemMessage(caller func(), mark ...log.Field) *Messa
return slf
}
// castToErrorMessage 将消息转换为错误消息
func (slf *Message) castToErrorMessage(err error, action MessageErrorAction, mark ...log.Field) *Message {
slf.t, slf.err, slf.errAction, slf.marks = MessageTypeError, err, action, mark
return slf
}
// castToShuntMessage 将消息转换为分流消息
func (slf *Message) castToShuntMessage(conn *Conn, caller func(), mark ...log.Field) *Message {
slf.t, slf.conn, slf.ordinaryHandler, slf.marks = MessageTypeShunt, conn, caller, mark

View File

@ -1,8 +1,6 @@
package server
import (
"github.com/kercylan98/minotaur/utils/log"
"github.com/kercylan98/minotaur/utils/network"
"github.com/xtaci/kcp-go/v5"
"math"
"os"
@ -71,16 +69,7 @@ func (slf *MultipleServer) Run() {
kcp.SystemTimedSched.Close()
}
log.Info("Server", log.String(serverMultipleMark, "===================================================================="))
ip, _ := network.IP()
for _, server := range slf.servers {
log.Info("Server", log.String(serverMultipleMark, "RunningInfo"),
log.Any("network", server.network),
log.String("ip", ip.String()),
log.String("listen", server.addr),
)
}
log.Info("Server", log.String(serverMultipleMark, "===================================================================="))
ShowServersInfo(serverMultipleMark, slf.servers...)
systemSignal := make(chan os.Signal, 1)
signal.Notify(systemSignal, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT)

View File

@ -1,6 +1,21 @@
package server
import "github.com/kercylan98/minotaur/utils/slice"
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/kercylan98/minotaur/server/internal/logger"
"github.com/kercylan98/minotaur/utils/hash"
"github.com/kercylan98/minotaur/utils/log"
"github.com/kercylan98/minotaur/utils/slice"
"github.com/kercylan98/minotaur/utils/super"
"github.com/panjf2000/gnet"
"github.com/xtaci/kcp-go/v5"
"google.golang.org/grpc"
"net"
"net/http"
"strings"
"time"
)
type Network string
@ -23,12 +38,271 @@ const (
)
var (
networks = []Network{
networkNameMap map[string]struct{}
networks = []Network{
NetworkNone, NetworkTcp, NetworkTcp4, NetworkTcp6, NetworkUdp, NetworkUdp4, NetworkUdp6, NetworkUnix, NetworkHttp, NetworkWebsocket, NetworkKcp, NetworkGRPC,
}
)
func init() {
networkNameMap = make(map[string]struct{}, len(networks))
for _, network := range networks {
networkNameMap[string(network)] = struct{}{}
}
}
// GetNetworks 获取所有支持的网络模式
func GetNetworks() []Network {
return slice.Copy(networks)
}
// check 检查网络模式是否支持
func (n Network) check() {
if !hash.Exist(networkNameMap, string(n)) {
panic(fmt.Errorf("unsupported network mode: %s", n))
}
}
// preprocessing 服务器预处理
func (n Network) preprocessing(srv *Server) {
switch n {
case NetworkNone:
case NetworkTcp:
case NetworkTcp4:
case NetworkTcp6:
case NetworkUdp:
case NetworkUdp4:
case NetworkUdp6:
case NetworkUnix:
case NetworkHttp:
srv.ginServer = gin.New()
srv.httpServer = &http.Server{
Handler: srv.ginServer,
}
case NetworkWebsocket:
srv.websocketReadDeadline = DefaultWebsocketReadDeadline
case NetworkKcp:
case NetworkGRPC:
srv.grpcServer = grpc.NewServer()
}
}
// adaptation 服务器适配
func (n Network) adaptation(srv *Server) <-chan error {
state := make(chan error, 1)
switch n {
case NetworkNone:
srv.addr = "-"
case NetworkTcp:
n.gNetMode(state, srv)
case NetworkTcp4:
n.gNetMode(state, srv)
case NetworkTcp6:
n.gNetMode(state, srv)
case NetworkUdp:
n.gNetMode(state, srv)
case NetworkUdp4:
n.gNetMode(state, srv)
case NetworkUdp6:
n.gNetMode(state, srv)
case NetworkUnix:
n.gNetMode(state, srv)
case NetworkHttp:
n.httpMode(state, srv)
case NetworkWebsocket:
n.websocketMode(state, srv)
case NetworkKcp:
n.kcpMode(state, srv)
case NetworkGRPC:
n.grpcMode(state, srv)
}
return state
}
// gNetMode gNet模式
func (n Network) gNetMode(state chan<- error, srv *Server) {
srv.gServer = &gNet{Server: srv, state: state}
go func(srv *Server) {
if err := gnet.Serve(srv.gServer, fmt.Sprintf("%s://%s", srv.network, srv.addr),
gnet.WithLogger(new(logger.GNet)),
gnet.WithTicker(true),
gnet.WithMulticore(true),
); err != nil {
srv.gServer.state <- err
}
}(srv)
}
// grpcMode grpc模式
func (n Network) grpcMode(state chan<- error, srv *Server) {
l, err := net.Listen(string(NetworkTcp), srv.addr)
if err != nil {
state <- err
return
}
lis := (&listener{srv: srv, Listener: l, state: state}).init()
go func(srv *Server, lis *listener) {
if err = srv.grpcServer.Serve(lis); err != nil {
lis.state <- err
}
}(srv, lis)
}
// kcpMode kcp模式
func (n Network) kcpMode(state chan<- error, srv *Server) {
l, err := kcp.ListenWithOptions(srv.addr, nil, 0, 0)
if err != nil {
state <- err
return
}
lis := (&listener{srv: srv, kcpListener: l, state: state}).init()
go func(lis *listener) {
for {
session, err := lis.AcceptKCP()
if err != nil {
continue
}
conn := newKcpConn(lis.srv, session)
lis.srv.OnConnectionOpenedEvent(conn)
go func(conn *Conn) {
defer func() {
if err := super.RecoverTransform(recover()); err != nil {
conn.Close(err)
}
}()
buf := make([]byte, 4096)
for !conn.IsClosed() {
n, err := conn.kcp.Read(buf)
if err != nil {
if conn.IsClosed() {
break
}
panic(err)
}
lis.srv.PushPacketMessage(conn, 0, buf[:n])
}
}(conn)
}
}(lis)
return
}
// httpMode http模式
func (n Network) httpMode(state chan<- error, srv *Server) {
srv.httpServer.Addr = srv.addr
l, err := net.Listen(string(NetworkTcp), srv.addr)
if err != nil {
state <- err
return
}
gin.SetMode(gin.ReleaseMode)
srv.ginServer.Use(func(c *gin.Context) {
t := time.Now()
c.Next()
log.Info("Server", log.String("type", "http"),
log.String("method", c.Request.Method), log.Int("status", c.Writer.Status()),
log.String("ip", c.ClientIP()), log.String("path", c.Request.URL.Path),
log.Duration("cost", time.Since(t)))
})
go func(lis *listener) {
var err error
if len(lis.srv.certFile)+len(srv.keyFile) > 0 {
err = lis.srv.httpServer.ServeTLS(lis, lis.srv.certFile, lis.srv.keyFile)
} else {
err = lis.srv.httpServer.Serve(lis)
}
if err != nil {
lis.state <- err
}
}((&listener{srv: srv, Listener: l, state: state}).init())
}
// websocketMode websocket模式
func (n Network) websocketMode(state chan<- error, srv *Server) {
l, err := net.Listen(string(NetworkTcp), srv.addr)
if err != nil {
state <- err
return
}
var pattern string
var index = strings.Index(srv.addr, "/")
if index == -1 {
pattern = "/"
} else {
pattern = srv.addr[index:]
srv.addr = srv.addr[:index]
}
if srv.websocketUpgrader == nil {
srv.websocketUpgrader = DefaultWebsocketUpgrader()
}
http.HandleFunc(pattern, func(writer http.ResponseWriter, request *http.Request) {
ip := request.Header.Get("X-Real-IP")
ws, err := srv.websocketUpgrader.Upgrade(writer, request, nil)
if err != nil {
return
}
if srv.websocketConnInitializer != nil {
if err = srv.websocketConnInitializer(writer, request, ws); err != nil {
return
}
}
if len(ip) == 0 {
addr := ws.RemoteAddr().String()
if index := strings.LastIndex(addr, ":"); index != -1 {
ip = addr[0:index]
}
}
if srv.websocketCompression > 0 {
_ = ws.SetCompressionLevel(srv.websocketCompression)
}
ws.EnableWriteCompression(srv.websocketWriteCompression)
conn := newWebsocketConn(srv, ws, ip)
conn.SetData(wsRequestKey, request)
for k, v := range request.URL.Query() {
if len(v) == 1 {
conn.SetData(k, v[0])
} else {
conn.SetData(k, v)
}
}
srv.OnConnectionOpenedEvent(conn)
defer func() {
if err := super.RecoverTransform(recover()); err != nil {
conn.Close(err)
}
}()
for !conn.IsClosed() {
if srv.websocketReadDeadline > 0 {
if err := ws.SetReadDeadline(time.Now().Add(srv.websocketReadDeadline)); err != nil {
panic(err)
}
}
messageType, packet, readErr := ws.ReadMessage()
if readErr != nil {
if conn.IsClosed() {
break
}
panic(readErr)
}
if len(srv.supportMessageTypes) > 0 && !srv.supportMessageTypes[messageType] {
panic(ErrWebsocketIllegalMessageType)
}
srv.PushPacketMessage(conn, messageType, packet)
}
})
go func(lis *listener) {
var err error
if len(lis.srv.certFile)+len(lis.srv.keyFile) > 0 {
err = http.ServeTLS(lis, nil, lis.srv.certFile, lis.srv.keyFile)
} else {
err = http.Serve(lis, nil)
}
if err != nil {
lis.state <- err
}
}((&listener{srv: srv, Listener: l, state: state}).init())
}

File diff suppressed because it is too large Load Diff

View File

@ -11,6 +11,32 @@ import (
func TestNew(t *testing.T) {
srv := server.New(server.NetworkWebsocket, server.WithPProf())
srv.RegStartBeforeEvent(func(srv *server.Server) {
fmt.Println("启动前")
})
srv.RegStartFinishEvent(func(srv *server.Server) {
fmt.Println("启动完成")
})
srv.RegConnectionClosedEvent(func(srv *server.Server, conn *server.Conn, err any) {
fmt.Println("关闭", conn.GetID(), err, "Count", srv.GetOnlineCount())
})
srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet []byte) {
conn.Write(packet)
})
if err := srv.Run(":9999"); err != nil {
panic(err)
}
}
func TestNew2(t *testing.T) {
srv := server.New(server.NetworkWebsocket, server.WithPProf())
srv.RegStartBeforeEvent(func(srv *server.Server) {
fmt.Println("启动前")
})
srv.RegStartFinishEvent(func(srv *server.Server) {
fmt.Println("启动完成")
})
srv.RegConnectionClosedEvent(func(srv *server.Server, conn *server.Conn, err any) {
fmt.Println("关闭", conn.GetID(), err, "Count", srv.GetOnlineCount())
})

View File

@ -28,10 +28,3 @@ func BindService(srv *Server, services ...Service) {
})
}
}
// onServicesInit 服务初始化
func onServicesInit(srv *Server) {
for _, service := range srv.services {
service()
}
}