From af0a5a1c25de46d2ea03abf777a40bd2723536c3 Mon Sep 17 00:00:00 2001 From: kercylan98 Date: Mon, 25 Dec 2023 17:40:55 +0800 Subject: [PATCH] =?UTF-8?q?style:=20=E4=BC=98=E5=8C=96=20server=20?= =?UTF-8?q?=E5=8C=85=E4=BB=A3=E7=A0=81=E5=8F=AF=E8=AF=BB=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/event.go | 52 ++- server/gnet.go | 27 +- server/listener.go | 34 ++ server/message.go | 29 -- server/multiple.go | 13 +- server/network.go | 278 ++++++++++++++- server/server.go | 811 ++++++++++++++++-------------------------- server/server_test.go | 26 ++ server/service.go | 7 - 9 files changed, 691 insertions(+), 586 deletions(-) create mode 100644 server/listener.go diff --git a/server/event.go b/server/event.go index bfd77f8..9471438 100644 --- a/server/event.go +++ b/server/event.go @@ -16,23 +16,27 @@ import ( ) type ( - StartBeforeEventHandler func(srv *Server) - StartFinishEventHandler func(srv *Server) - StopEventHandler func(srv *Server) - ConnectionReceivePacketEventHandler func(srv *Server, conn *Conn, packet []byte) + MessageReadyEventHandler func(srv *Server) + StartBeforeEventHandler func(srv *Server) + StartFinishEventHandler func(srv *Server) + StopEventHandler func(srv *Server) + ConnectionOpenedEventHandler func(srv *Server, conn *Conn) - ConnectionClosedEventHandler func(srv *Server, conn *Conn, err any) - MessageErrorEventHandler func(srv *Server, message *Message, err error) - MessageLowExecEventHandler func(srv *Server, message *Message, cost time.Duration) - ConsoleCommandEventHandler func(srv *Server, command string, params ConsoleParams) ConnectionOpenedAfterEventHandler func(srv *Server, conn *Conn) - ConnectionWritePacketBeforeEventHandler func(srv *Server, conn *Conn, packet []byte) []byte - ShuntChannelCreatedEventHandler func(srv *Server, guid int64) - ShuntChannelClosedEventHandler func(srv *Server, guid int64) ConnectionPacketPreprocessEventHandler func(srv *Server, conn *Conn, packet []byte, abort func(), usePacket func(newPacket []byte)) - MessageExecBeforeEventHandler func(srv *Server, message *Message) bool - MessageReadyEventHandler func(srv *Server) - OnDeadlockDetectEventHandler func(srv *Server, message *Message) + ConnectionReceivePacketEventHandler func(srv *Server, conn *Conn, packet []byte) + ConnectionWritePacketBeforeEventHandler func(srv *Server, conn *Conn, packet []byte) []byte + ConnectionClosedEventHandler func(srv *Server, conn *Conn, err any) + + ShuntChannelCreatedEventHandler func(srv *Server, name string) + ShuntChannelClosedEventHandler func(srv *Server, name string) + + MessageExecBeforeEventHandler func(srv *Server, message *Message) bool + MessageLowExecEventHandler func(srv *Server, message *Message, cost time.Duration) + MessageErrorEventHandler func(srv *Server, message *Message, err error) + + ConsoleCommandEventHandler func(srv *Server, command string, params ConsoleParams) + OnDeadlockDetectEventHandler func(srv *Server, message *Message) ) func newEvent(srv *Server) *event { @@ -207,7 +211,7 @@ func (slf *event) RegConnectionClosedEvent(handler ConnectionClosedEventHandler, func (slf *event) OnConnectionClosedEvent(conn *Conn, err any) { slf.PushShuntMessage(conn, func() { - slf.Server.online.Delete(conn.GetID()) + slf.Server.online.Del(conn.GetID()) slf.connectionClosedEventHandlers.RangeValue(func(index int, value ConnectionClosedEventHandler) bool { value(slf.Server, conn, err) return true @@ -340,10 +344,10 @@ func (slf *event) RegShuntChannelCreatedEvent(handler ShuntChannelCreatedEventHa log.Info("Server", log.String("RegEvent", runtimes.CurrentRunningFuncName()), log.String("handler", reflect.TypeOf(handler).String())) } -func (slf *event) OnShuntChannelCreatedEvent(guid int64) { +func (slf *event) OnShuntChannelCreatedEvent(name string) { slf.PushSystemMessage(func() { slf.shuntChannelCreatedEventHandlers.RangeValue(func(index int, value ShuntChannelCreatedEventHandler) bool { - value(slf.Server, guid) + value(slf.Server, name) return true }) }, log.String("Event", "OnShuntChannelCreatedEvent")) @@ -355,10 +359,10 @@ func (slf *event) RegShuntChannelCloseEvent(handler ShuntChannelClosedEventHandl log.Info("Server", log.String("RegEvent", runtimes.CurrentRunningFuncName()), log.String("handler", reflect.TypeOf(handler).String())) } -func (slf *event) OnShuntChannelClosedEvent(guid int64) { +func (slf *event) OnShuntChannelClosedEvent(name string) { slf.PushSystemMessage(func() { slf.shuntChannelClosedEventHandlers.RangeValue(func(index int, value ShuntChannelClosedEventHandler) bool { - value(slf.Server, guid) + value(slf.Server, name) return true }) }, log.String("Event", "OnShuntChannelClosedEvent")) @@ -460,13 +464,3 @@ func (slf *event) OnDeadlockDetectEvent(message *Message) { return true }) } - -func (slf *event) check() { - switch slf.network { - case NetworkHttp, NetworkGRPC, NetworkNone: - default: - if slf.connectionReceivePacketEventHandlers.Len() == 0 { - log.Warn("Server", log.String("ConnectionReceivePacketEvent", "invalid server, no packets processed")) - } - } -} diff --git a/server/gnet.go b/server/gnet.go index 77f8d60..7b9b5d6 100644 --- a/server/gnet.go +++ b/server/gnet.go @@ -8,42 +8,47 @@ import ( type gNet struct { *Server + state chan<- error } -func (slf *gNet) OnInitComplete(server gnet.Server) (action gnet.Action) { +func (g *gNet) OnInitComplete(server gnet.Server) (action gnet.Action) { + if g.state != nil { + g.state <- nil + g.state = nil + } return } -func (slf *gNet) OnShutdown(server gnet.Server) { +func (g *gNet) OnShutdown(server gnet.Server) { return } -func (slf *gNet) OnOpened(c gnet.Conn) (out []byte, action gnet.Action) { - conn := newGNetConn(slf.Server, c) +func (g *gNet) OnOpened(c gnet.Conn) (out []byte, action gnet.Action) { + conn := newGNetConn(g.Server, c) c.SetContext(conn) - slf.OnConnectionOpenedEvent(conn) + g.OnConnectionOpenedEvent(conn) return } -func (slf *gNet) OnClosed(c gnet.Conn, err error) (action gnet.Action) { +func (g *gNet) OnClosed(c gnet.Conn, err error) (action gnet.Action) { conn := c.Context().(*Conn) conn.Close(err) return } -func (slf *gNet) PreWrite(c gnet.Conn) { +func (g *gNet) PreWrite(c gnet.Conn) { return } -func (slf *gNet) AfterWrite(c gnet.Conn, b []byte) { +func (g *gNet) AfterWrite(c gnet.Conn, b []byte) { return } -func (slf *gNet) React(packet []byte, c gnet.Conn) (out []byte, action gnet.Action) { - slf.Server.PushPacketMessage(c.Context().(*Conn), 0, bytes.Clone(packet)) +func (g *gNet) React(packet []byte, c gnet.Conn) (out []byte, action gnet.Action) { + g.Server.PushPacketMessage(c.Context().(*Conn), 0, bytes.Clone(packet)) return nil, gnet.None } -func (slf *gNet) Tick() (delay time.Duration, action gnet.Action) { +func (g *gNet) Tick() (delay time.Duration, action gnet.Action) { return } diff --git a/server/listener.go b/server/listener.go new file mode 100644 index 0000000..e75ee92 --- /dev/null +++ b/server/listener.go @@ -0,0 +1,34 @@ +package server + +import ( + "github.com/xtaci/kcp-go/v5" + "net" + "sync" +) + +type listener struct { + srv *Server + once sync.Once + net.Listener + kcpListener *kcp.Listener + state chan<- error +} + +func (l *listener) init() *listener { + l.srv.OnStartBeforeEvent() + return l +} + +func (l *listener) Accept() (net.Conn, error) { + l.once.Do(func() { + l.state <- nil + }) + return l.Listener.Accept() +} + +func (l *listener) AcceptKCP() (*kcp.UDPSession, error) { + l.once.Do(func() { + l.state <- nil + }) + return l.kcpListener.AcceptKCP() +} diff --git a/server/message.go b/server/message.go index d339566..0af8fa4 100644 --- a/server/message.go +++ b/server/message.go @@ -10,9 +10,6 @@ const ( // MessageTypePacket 数据包消息类型:该类型的数据将被发送到 ConnectionReceivePacketEvent 进行处理 MessageTypePacket MessageType = iota + 1 - // MessageTypeError 错误消息类型:根据不同的错误状态,将交由 Server 进行统一处理 - MessageTypeError - // MessageTypeTicker 定时器消息类型 MessageTypeTicker @@ -52,7 +49,6 @@ const ( var messageNames = map[MessageType]string{ MessageTypePacket: "MessageTypePacket", - MessageTypeError: "MessageTypeError", MessageTypeTicker: "MessageTypeTicker", MessageTypeShuntTicker: "MessageTypeShuntTicker", MessageTypeAsync: "MessageTypeAsync", @@ -67,22 +63,9 @@ var messageNames = map[MessageType]string{ MessageTypeShunt: "MessageTypeShunt", } -const ( - MessageErrorActionNone MessageErrorAction = iota + 1 // 错误消息类型操作:将不会被进行任何特殊处理,仅进行日志输出 - MessageErrorActionShutdown // 错误消息类型操作:当接收到该类型的操作时,服务器将执行 Server.shutdown 函数 -) - -var messageErrorActionNames = map[MessageErrorAction]string{ - MessageErrorActionNone: "None", - MessageErrorActionShutdown: "Shutdown", -} - type ( // MessageType 消息类型 MessageType byte - - // MessageErrorAction 错误消息类型操作 - MessageErrorAction byte ) // HasMessageType 检查是否存在指定的消息类型 @@ -90,10 +73,6 @@ func HasMessageType(mt MessageType) bool { return hash.Exist(messageNames, mt) } -func (slf MessageErrorAction) String() string { - return messageErrorActionNames[slf] -} - // Message 服务器消息 type Message struct { conn *Conn @@ -104,7 +83,6 @@ type Message struct { err error name string t MessageType - errAction MessageErrorAction marks []log.Field } @@ -118,7 +96,6 @@ func (slf *Message) reset() { slf.err = nil slf.name = "" slf.t = 0 - slf.errAction = 0 slf.marks = nil } @@ -219,12 +196,6 @@ func (slf *Message) castToSystemMessage(caller func(), mark ...log.Field) *Messa return slf } -// castToErrorMessage 将消息转换为错误消息 -func (slf *Message) castToErrorMessage(err error, action MessageErrorAction, mark ...log.Field) *Message { - slf.t, slf.err, slf.errAction, slf.marks = MessageTypeError, err, action, mark - return slf -} - // castToShuntMessage 将消息转换为分流消息 func (slf *Message) castToShuntMessage(conn *Conn, caller func(), mark ...log.Field) *Message { slf.t, slf.conn, slf.ordinaryHandler, slf.marks = MessageTypeShunt, conn, caller, mark diff --git a/server/multiple.go b/server/multiple.go index 0335a16..1670425 100644 --- a/server/multiple.go +++ b/server/multiple.go @@ -1,8 +1,6 @@ package server import ( - "github.com/kercylan98/minotaur/utils/log" - "github.com/kercylan98/minotaur/utils/network" "github.com/xtaci/kcp-go/v5" "math" "os" @@ -71,16 +69,7 @@ func (slf *MultipleServer) Run() { kcp.SystemTimedSched.Close() } - log.Info("Server", log.String(serverMultipleMark, "====================================================================")) - ip, _ := network.IP() - for _, server := range slf.servers { - log.Info("Server", log.String(serverMultipleMark, "RunningInfo"), - log.Any("network", server.network), - log.String("ip", ip.String()), - log.String("listen", server.addr), - ) - } - log.Info("Server", log.String(serverMultipleMark, "====================================================================")) + ShowServersInfo(serverMultipleMark, slf.servers...) systemSignal := make(chan os.Signal, 1) signal.Notify(systemSignal, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT) diff --git a/server/network.go b/server/network.go index 62ce882..9f54d9d 100644 --- a/server/network.go +++ b/server/network.go @@ -1,6 +1,21 @@ package server -import "github.com/kercylan98/minotaur/utils/slice" +import ( + "fmt" + "github.com/gin-gonic/gin" + "github.com/kercylan98/minotaur/server/internal/logger" + "github.com/kercylan98/minotaur/utils/hash" + "github.com/kercylan98/minotaur/utils/log" + "github.com/kercylan98/minotaur/utils/slice" + "github.com/kercylan98/minotaur/utils/super" + "github.com/panjf2000/gnet" + "github.com/xtaci/kcp-go/v5" + "google.golang.org/grpc" + "net" + "net/http" + "strings" + "time" +) type Network string @@ -23,12 +38,271 @@ const ( ) var ( - networks = []Network{ + networkNameMap map[string]struct{} + networks = []Network{ NetworkNone, NetworkTcp, NetworkTcp4, NetworkTcp6, NetworkUdp, NetworkUdp4, NetworkUdp6, NetworkUnix, NetworkHttp, NetworkWebsocket, NetworkKcp, NetworkGRPC, } ) +func init() { + networkNameMap = make(map[string]struct{}, len(networks)) + for _, network := range networks { + networkNameMap[string(network)] = struct{}{} + } +} + // GetNetworks 获取所有支持的网络模式 func GetNetworks() []Network { return slice.Copy(networks) } + +// check 检查网络模式是否支持 +func (n Network) check() { + if !hash.Exist(networkNameMap, string(n)) { + panic(fmt.Errorf("unsupported network mode: %s", n)) + } +} + +// preprocessing 服务器预处理 +func (n Network) preprocessing(srv *Server) { + switch n { + case NetworkNone: + case NetworkTcp: + case NetworkTcp4: + case NetworkTcp6: + case NetworkUdp: + case NetworkUdp4: + case NetworkUdp6: + case NetworkUnix: + case NetworkHttp: + srv.ginServer = gin.New() + srv.httpServer = &http.Server{ + Handler: srv.ginServer, + } + case NetworkWebsocket: + srv.websocketReadDeadline = DefaultWebsocketReadDeadline + case NetworkKcp: + case NetworkGRPC: + srv.grpcServer = grpc.NewServer() + } +} + +// adaptation 服务器适配 +func (n Network) adaptation(srv *Server) <-chan error { + state := make(chan error, 1) + switch n { + case NetworkNone: + srv.addr = "-" + case NetworkTcp: + n.gNetMode(state, srv) + case NetworkTcp4: + n.gNetMode(state, srv) + case NetworkTcp6: + n.gNetMode(state, srv) + case NetworkUdp: + n.gNetMode(state, srv) + case NetworkUdp4: + n.gNetMode(state, srv) + case NetworkUdp6: + n.gNetMode(state, srv) + case NetworkUnix: + n.gNetMode(state, srv) + case NetworkHttp: + n.httpMode(state, srv) + case NetworkWebsocket: + n.websocketMode(state, srv) + case NetworkKcp: + n.kcpMode(state, srv) + case NetworkGRPC: + n.grpcMode(state, srv) + } + return state +} + +// gNetMode gNet模式 +func (n Network) gNetMode(state chan<- error, srv *Server) { + srv.gServer = &gNet{Server: srv, state: state} + go func(srv *Server) { + if err := gnet.Serve(srv.gServer, fmt.Sprintf("%s://%s", srv.network, srv.addr), + gnet.WithLogger(new(logger.GNet)), + gnet.WithTicker(true), + gnet.WithMulticore(true), + ); err != nil { + srv.gServer.state <- err + } + }(srv) +} + +// grpcMode grpc模式 +func (n Network) grpcMode(state chan<- error, srv *Server) { + l, err := net.Listen(string(NetworkTcp), srv.addr) + if err != nil { + state <- err + return + } + lis := (&listener{srv: srv, Listener: l, state: state}).init() + go func(srv *Server, lis *listener) { + if err = srv.grpcServer.Serve(lis); err != nil { + lis.state <- err + } + }(srv, lis) +} + +// kcpMode kcp模式 +func (n Network) kcpMode(state chan<- error, srv *Server) { + l, err := kcp.ListenWithOptions(srv.addr, nil, 0, 0) + if err != nil { + state <- err + return + } + lis := (&listener{srv: srv, kcpListener: l, state: state}).init() + go func(lis *listener) { + for { + session, err := lis.AcceptKCP() + if err != nil { + continue + } + + conn := newKcpConn(lis.srv, session) + lis.srv.OnConnectionOpenedEvent(conn) + + go func(conn *Conn) { + defer func() { + if err := super.RecoverTransform(recover()); err != nil { + conn.Close(err) + } + }() + + buf := make([]byte, 4096) + for !conn.IsClosed() { + n, err := conn.kcp.Read(buf) + if err != nil { + if conn.IsClosed() { + break + } + panic(err) + } + lis.srv.PushPacketMessage(conn, 0, buf[:n]) + } + }(conn) + } + }(lis) + return +} + +// httpMode http模式 +func (n Network) httpMode(state chan<- error, srv *Server) { + srv.httpServer.Addr = srv.addr + l, err := net.Listen(string(NetworkTcp), srv.addr) + if err != nil { + state <- err + return + } + gin.SetMode(gin.ReleaseMode) + srv.ginServer.Use(func(c *gin.Context) { + t := time.Now() + c.Next() + log.Info("Server", log.String("type", "http"), + log.String("method", c.Request.Method), log.Int("status", c.Writer.Status()), + log.String("ip", c.ClientIP()), log.String("path", c.Request.URL.Path), + log.Duration("cost", time.Since(t))) + }) + go func(lis *listener) { + var err error + if len(lis.srv.certFile)+len(srv.keyFile) > 0 { + err = lis.srv.httpServer.ServeTLS(lis, lis.srv.certFile, lis.srv.keyFile) + } else { + err = lis.srv.httpServer.Serve(lis) + } + if err != nil { + lis.state <- err + } + }((&listener{srv: srv, Listener: l, state: state}).init()) +} + +// websocketMode websocket模式 +func (n Network) websocketMode(state chan<- error, srv *Server) { + l, err := net.Listen(string(NetworkTcp), srv.addr) + if err != nil { + state <- err + return + } + var pattern string + var index = strings.Index(srv.addr, "/") + if index == -1 { + pattern = "/" + } else { + pattern = srv.addr[index:] + srv.addr = srv.addr[:index] + } + if srv.websocketUpgrader == nil { + srv.websocketUpgrader = DefaultWebsocketUpgrader() + } + http.HandleFunc(pattern, func(writer http.ResponseWriter, request *http.Request) { + ip := request.Header.Get("X-Real-IP") + ws, err := srv.websocketUpgrader.Upgrade(writer, request, nil) + if err != nil { + return + } + if srv.websocketConnInitializer != nil { + if err = srv.websocketConnInitializer(writer, request, ws); err != nil { + return + } + } + if len(ip) == 0 { + addr := ws.RemoteAddr().String() + if index := strings.LastIndex(addr, ":"); index != -1 { + ip = addr[0:index] + } + } + if srv.websocketCompression > 0 { + _ = ws.SetCompressionLevel(srv.websocketCompression) + } + ws.EnableWriteCompression(srv.websocketWriteCompression) + conn := newWebsocketConn(srv, ws, ip) + conn.SetData(wsRequestKey, request) + for k, v := range request.URL.Query() { + if len(v) == 1 { + conn.SetData(k, v[0]) + } else { + conn.SetData(k, v) + } + } + srv.OnConnectionOpenedEvent(conn) + + defer func() { + if err := super.RecoverTransform(recover()); err != nil { + conn.Close(err) + } + }() + for !conn.IsClosed() { + if srv.websocketReadDeadline > 0 { + if err := ws.SetReadDeadline(time.Now().Add(srv.websocketReadDeadline)); err != nil { + panic(err) + } + } + messageType, packet, readErr := ws.ReadMessage() + if readErr != nil { + if conn.IsClosed() { + break + } + panic(readErr) + } + if len(srv.supportMessageTypes) > 0 && !srv.supportMessageTypes[messageType] { + panic(ErrWebsocketIllegalMessageType) + } + srv.PushPacketMessage(conn, messageType, packet) + } + }) + go func(lis *listener) { + var err error + if len(lis.srv.certFile)+len(lis.srv.keyFile) > 0 { + err = http.ServeTLS(lis, nil, lis.srv.certFile, lis.srv.keyFile) + } else { + err = http.Serve(lis, nil) + } + if err != nil { + lis.state <- err + } + }((&listener{srv: srv, Listener: l, state: state}).init()) +} diff --git a/server/server.go b/server/server.go index 1b8daa8..94743ee 100644 --- a/server/server.go +++ b/server/server.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "github.com/alphadose/haxmap" "github.com/gin-gonic/gin" "github.com/kercylan98/minotaur/server/internal/logger" "github.com/kercylan98/minotaur/utils/concurrent" @@ -16,12 +17,10 @@ import ( "github.com/panjf2000/gnet" "github.com/xtaci/kcp-go/v5" "google.golang.org/grpc" - "net" "net/http" "os" "os/signal" "runtime/debug" - "strings" "sync" "sync/atomic" "syscall" @@ -30,6 +29,7 @@ import ( // New 根据特定网络类型创建一个服务器 func New(network Network, options ...Option) *Server { + network.check() server := &Server{ runtime: &runtime{ packetWarnSize: DefaultPacketWarnSize, @@ -38,33 +38,20 @@ func New(network Network, options ...Option) *Server { }, option: &option{}, network: network, - online: concurrent.NewBalanceMap[string, *Conn](), + online: haxmap.New[string, *Conn](), closeChannel: make(chan struct{}, 1), systemSignal: make(chan os.Signal, 1), - ctx: context.Background(), dispatchers: make(map[string]*dispatcher), dispatcherMember: map[string]map[string]*Conn{}, currDispatcher: map[string]*dispatcher{}, } - server.ctx, server.cancel = context.WithCancel(server.ctx) + server.ctx, server.cancel = context.WithCancel(context.Background()) server.event = newEvent(server) - switch network { - case NetworkHttp: - server.ginServer = gin.New() - server.httpServer = &http.Server{ - Handler: server.ginServer, - } - case NetworkGRPC: - server.grpcServer = grpc.NewServer() - case NetworkWebsocket: - server.websocketReadDeadline = DefaultWebsocketReadDeadline - } - + network.preprocessing(server) for _, option := range options { option(server) } - if !server.disableAnts { if server.antsPoolSize <= 0 { server.antsPoolSize = DefaultAsyncPoolSize @@ -75,40 +62,51 @@ func New(network Network, options ...Option) *Server { panic(err) } } - server.option = nil return server } // Server 网络服务器 type Server struct { - *event // 事件 - *runtime // 运行时 - *option // 可选项 - ginServer *gin.Engine // HTTP模式下的路由器 - httpServer *http.Server // HTTP模式下的服务器 - grpcServer *grpc.Server // GRPC模式下的服务器 - gServer *gNet // TCP或UDP模式下的服务器 - multiple *MultipleServer // 多服务器模式下的服务器 - ants *ants.Pool // 协程池 - messagePool *concurrent.Pool[*Message] // 消息池 - ctx context.Context // 上下文 - cancel context.CancelFunc // 停止上下文 - online *concurrent.BalanceMap[string, *Conn] // 在线连接 - systemDispatcher *dispatcher // 系统消息分发器 - systemSignal chan os.Signal // 系统信号 - closeChannel chan struct{} // 关闭信号 - multipleRuntimeErrorChan chan error // 多服务器模式下的运行时错误 - dispatchers map[string]*dispatcher // 消息分发器集合 - dispatcherMember map[string]map[string]*Conn // 消息分发器包含的连接 - currDispatcher map[string]*dispatcher // 当前连接所处消息分发器 - dispatcherLock sync.RWMutex // 消息分发器锁 - isShutdown atomic.Bool // 是否已关闭 - messageCounter atomic.Int64 // 消息计数器 - addr string // 侦听地址 - network Network // 网络类型 - isRunning bool // 是否正在运行 - services []func() // 服务 + *event // 事件 + *runtime // 运行时 + *option // 可选项 + ginServer *gin.Engine // HTTP模式下的路由器 + httpServer *http.Server // HTTP模式下的服务器 + grpcServer *grpc.Server // GRPC模式下的服务器 + gServer *gNet // TCP或UDP模式下的服务器 + multiple *MultipleServer // 多服务器模式下的服务器 + ants *ants.Pool // 协程池 + messagePool *concurrent.Pool[*Message] // 消息池 + ctx context.Context // 上下文 + cancel context.CancelFunc // 停止上下文 + online *haxmap.Map[string, *Conn] // 在线连接 + systemDispatcher *dispatcher // 系统消息分发器 + systemSignal chan os.Signal // 系统信号 + closeChannel chan struct{} // 关闭信号 + multipleRuntimeErrorChan chan error // 多服务器模式下的运行时错误 + dispatchers map[string]*dispatcher // 消息分发器集合 + dispatcherMember map[string]map[string]*Conn // 消息分发器包含的连接 + currDispatcher map[string]*dispatcher // 当前连接所处消息分发器 + dispatcherLock sync.RWMutex // 消息分发器锁 + messageCounter atomic.Int64 // 消息计数器 + addr string // 侦听地址 + network Network // 网络类型 + closed uint32 // 服务器是否已关闭 + services []func() // 服务 +} + +// preCheckAndAdaptation 预检查及适配 +func (srv *Server) preCheckAndAdaptation(addr string) (startState <-chan error, err error) { + if srv.event == nil { + return nil, ErrConstructed + } + srv.addr = addr + if srv.multiple == nil && srv.network != NetworkKcp { + kcp.SystemTimedSched.Close() + } + + return srv.network.adaptation(srv), nil } // Run 使用特定地址运行服务器 @@ -123,245 +121,31 @@ type Server struct { // - server.NetworkWebsocket (addr:":8888/ws") // - server.NetworkKcp (addr:":8888") // - server.NetworkNone (addr:"") -func (slf *Server) Run(addr string) error { - if slf.network == NetworkNone { - addr = "-" +func (srv *Server) Run(addr string) (err error) { + var startState <-chan error + if startState, err = srv.preCheckAndAdaptation(addr); err != nil { + return err } - if slf.event == nil { - return ErrConstructed + onServicesInit(srv) + onMessageSystemInit(srv) + if srv.multiple == nil { + ShowServersInfo(serverMark, srv) } - onServicesInit(slf) - slf.event.check() - slf.addr = addr - slf.startMessageStatistics() - slf.systemDispatcher = generateDispatcher(slf.dispatcherBufferSize, serverSystemDispatcher, slf.dispatchMessage) - slf.messagePool = concurrent.NewPool[Message]( - func() *Message { - return &Message{} - }, - func(data *Message) { - data.reset() - }, - ) - if slf.network != NetworkHttp && slf.network != NetworkWebsocket && slf.network != NetworkGRPC { - slf.gServer = &gNet{Server: slf} + if err = <-startState; err != nil { + return err } - var protoAddr = fmt.Sprintf("%s://%s", slf.network, slf.addr) - go slf.systemDispatcher.start() + srv.OnStartFinishEvent() - switch slf.network { - case NetworkNone: - slf.isRunning = true - slf.OnStartBeforeEvent() - case NetworkGRPC: - listener, err := net.Listen(string(NetworkTcp), slf.addr) - if err != nil { - return err - } - go func() { - slf.isRunning = true - slf.OnStartBeforeEvent() - if err := slf.grpcServer.Serve(listener); err != nil { - slf.isRunning = false - slf.PushErrorMessage(err, MessageErrorActionShutdown) - } - }() - case NetworkTcp, NetworkTcp4, NetworkTcp6, NetworkUdp, NetworkUdp4, NetworkUdp6, NetworkUnix: - slf.isRunning = true - slf.OnStartBeforeEvent() - if err := gnet.Serve(slf.gServer, protoAddr, - gnet.WithLogger(new(logger.GNet)), - gnet.WithTicker(true), - gnet.WithMulticore(true), - ); err != nil { - slf.isRunning = false - slf.PushErrorMessage(err, MessageErrorActionShutdown) - } - case NetworkKcp: - listener, err := kcp.ListenWithOptions(slf.addr, nil, 0, 0) - if err != nil { - return err - } - slf.isRunning = true - slf.OnStartBeforeEvent() - for { - session, err := listener.AcceptKCP() - if err != nil { - continue - } - - conn := newKcpConn(slf, session) - slf.OnConnectionOpenedEvent(conn) - - go func(conn *Conn) { - defer func() { - if err := super.RecoverTransform(recover()); err != nil { - conn.Close(err) - } - }() - - buf := make([]byte, 4096) - for !conn.IsClosed() { - n, err := conn.kcp.Read(buf) - if err != nil { - if conn.IsClosed() { - break - } - panic(err) - } - slf.PushPacketMessage(conn, 0, buf[:n]) - } - }(conn) - } - case NetworkHttp: - go func() { - slf.isRunning = true - slf.OnStartBeforeEvent() - slf.httpServer.Addr = slf.addr - gin.SetMode(gin.ReleaseMode) - slf.ginServer.Use(func(c *gin.Context) { - t := time.Now() - c.Next() - log.Info("Server", log.String("type", "http"), - log.String("method", c.Request.Method), log.Int("status", c.Writer.Status()), - log.String("ip", c.ClientIP()), log.String("path", c.Request.URL.Path), - log.Duration("cost", time.Since(t))) - }) - if len(slf.certFile)+len(slf.keyFile) > 0 { - if err := slf.httpServer.ListenAndServeTLS(slf.certFile, slf.keyFile); err != nil { - slf.isRunning = false - slf.PushErrorMessage(err, MessageErrorActionShutdown) - } - } else { - if err := slf.httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - slf.isRunning = false - slf.PushErrorMessage(err, MessageErrorActionShutdown) - } - } - - }() - case NetworkWebsocket: - var pattern string - var index = strings.Index(addr, "/") - if index == -1 { - pattern = "/" - } else { - pattern = addr[index:] - slf.addr = slf.addr[:index] - } - if slf.websocketUpgrader == nil { - slf.websocketUpgrader = DefaultWebsocketUpgrader() - } - http.HandleFunc(pattern, func(writer http.ResponseWriter, request *http.Request) { - ip := request.Header.Get("X-Real-IP") - ws, err := slf.websocketUpgrader.Upgrade(writer, request, nil) - if err != nil { - return - } - if slf.websocketConnInitializer != nil { - if err = slf.websocketConnInitializer(writer, request, ws); err != nil { - return - } - } - if len(ip) == 0 { - addr := ws.RemoteAddr().String() - if index := strings.LastIndex(addr, ":"); index != -1 { - ip = addr[0:index] - } - } - if slf.websocketCompression > 0 { - _ = ws.SetCompressionLevel(slf.websocketCompression) - } - ws.EnableWriteCompression(slf.websocketWriteCompression) - conn := newWebsocketConn(slf, ws, ip) - conn.SetData(wsRequestKey, request) - for k, v := range request.URL.Query() { - if len(v) == 1 { - conn.SetData(k, v[0]) - } else { - conn.SetData(k, v) - } - } - slf.OnConnectionOpenedEvent(conn) - - defer func() { - if err := super.RecoverTransform(recover()); err != nil { - conn.Close(err) - } - }() - for !conn.IsClosed() { - if slf.websocketReadDeadline > 0 { - if err := ws.SetReadDeadline(time.Now().Add(slf.websocketReadDeadline)); err != nil { - panic(err) - } - } - messageType, packet, readErr := ws.ReadMessage() - if readErr != nil { - if conn.IsClosed() { - break - } - panic(readErr) - } - if len(slf.supportMessageTypes) > 0 && !slf.supportMessageTypes[messageType] { - panic(ErrWebsocketIllegalMessageType) - } - slf.PushPacketMessage(conn, messageType, packet) - } - }) - go func() { - slf.isRunning = true - slf.OnStartBeforeEvent() - if len(slf.certFile)+len(slf.keyFile) > 0 { - if err := http.ListenAndServeTLS(slf.addr, slf.certFile, slf.keyFile, nil); err != nil { - slf.isRunning = false - slf.PushErrorMessage(err, MessageErrorActionShutdown) - } - } else { - if err := http.ListenAndServe(slf.addr, nil); err != nil { - slf.isRunning = false - slf.PushErrorMessage(err, MessageErrorActionShutdown) - } - } - - }() - default: - return ErrCanNotSupportNetwork - } - - if slf.multiple == nil && slf.network != NetworkKcp { - kcp.SystemTimedSched.Close() - } - - if slf.multiple == nil { - ip, _ := network.IP() - log.Info("Server", log.String(serverMark, "====================================================================")) - log.Info("Server", log.String(serverMark, "RunningInfo"), - log.Any("network", slf.network), - log.String("ip", ip.String()), - log.String("listen", slf.addr), - ) - log.Info("Server", log.String(serverMark, "====================================================================")) - slf.OnStartFinishEvent() - time.Sleep(time.Second) - if !slf.isShutdown.Load() { - slf.OnMessageReadyEvent() - } - - signal.Notify(slf.systemSignal, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT) + if srv.multiple == nil { + signal.Notify(srv.systemSignal, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT) select { - case <-slf.systemSignal: - slf.shutdown(nil) + case <-srv.systemSignal: + srv.shutdown(nil) } select { - case <-slf.closeChannel: - close(slf.closeChannel) - } - } else { - slf.OnStartFinishEvent() - time.Sleep(time.Second) - if !slf.isShutdown.Load() { - slf.OnMessageReadyEvent() + case <-srv.closeChannel: + close(srv.closeChannel) } } @@ -369,36 +153,36 @@ func (slf *Server) Run(addr string) error { } // IsSocket 是否是 Socket 模式 -func (slf *Server) IsSocket() bool { - return slf.network == NetworkTcp || slf.network == NetworkTcp4 || slf.network == NetworkTcp6 || - slf.network == NetworkUdp || slf.network == NetworkUdp4 || slf.network == NetworkUdp6 || - slf.network == NetworkUnix || slf.network == NetworkKcp || slf.network == NetworkWebsocket +func (srv *Server) IsSocket() bool { + return srv.network == NetworkTcp || srv.network == NetworkTcp4 || srv.network == NetworkTcp6 || + srv.network == NetworkUdp || srv.network == NetworkUdp4 || srv.network == NetworkUdp6 || + srv.network == NetworkUnix || srv.network == NetworkKcp || srv.network == NetworkWebsocket } // RunNone 是 Run("") 的简写,仅适用于运行 NetworkNone 服务器 -func (slf *Server) RunNone() error { - return slf.Run(str.None) +func (srv *Server) RunNone() error { + return srv.Run(str.None) } // Context 获取服务器上下文 -func (slf *Server) Context() context.Context { - return slf.ctx +func (srv *Server) Context() context.Context { + return srv.ctx } // TimeoutContext 获取服务器超时上下文,context.WithTimeout 的简写 -func (slf *Server) TimeoutContext(timeout time.Duration) (context.Context, context.CancelFunc) { - return context.WithTimeout(slf.ctx, timeout) +func (srv *Server) TimeoutContext(timeout time.Duration) (context.Context, context.CancelFunc) { + return context.WithTimeout(srv.ctx, timeout) } // GetOnlineCount 获取在线人数 -func (slf *Server) GetOnlineCount() int { - return slf.online.Size() +func (srv *Server) GetOnlineCount() int { + return int(srv.online.Len()) } // GetOnlineBotCount 获取在线机器人数量 -func (slf *Server) GetOnlineBotCount() int { +func (srv *Server) GetOnlineBotCount() int { var count int - slf.online.Range(func(id string, conn *Conn) bool { + srv.online.ForEach(func(id string, conn *Conn) bool { if conn.IsBot() { count++ } @@ -408,201 +192,214 @@ func (slf *Server) GetOnlineBotCount() int { } // GetOnline 获取在线连接 -func (slf *Server) GetOnline(id string) *Conn { - return slf.online.Get(id) +func (srv *Server) GetOnline(id string) *Conn { + c, _ := srv.online.Get(id) + return c } // GetOnlineAll 获取所有在线连接 -func (slf *Server) GetOnlineAll() map[string]*Conn { - return slf.online.Map() +func (srv *Server) GetOnlineAll() map[string]*Conn { + var m = map[string]*Conn{} + srv.online.ForEach(func(id string, conn *Conn) bool { + m[id] = conn + return true + }) + return m } // IsOnline 是否在线 -func (slf *Server) IsOnline(id string) bool { - return slf.online.Exist(id) +func (srv *Server) IsOnline(id string) bool { + _, exist := srv.online.Get(id) + return exist } // CloseConn 关闭连接 -func (slf *Server) CloseConn(id string) { - if conn, exist := slf.online.GetExist(id); exist { +func (srv *Server) CloseConn(id string) { + if conn, exist := srv.online.Get(id); exist { conn.Close() } } // Ticker 获取服务器定时器 -func (slf *Server) Ticker() *timer.Ticker { - if slf.ticker == nil { +func (srv *Server) Ticker() *timer.Ticker { + if srv.ticker == nil { panic(ErrNoSupportTicker) } - return slf.ticker + return srv.ticker } // Shutdown 主动停止运行服务器 -func (slf *Server) Shutdown() { - slf.systemSignal <- syscall.SIGQUIT +func (srv *Server) Shutdown() { + srv.systemSignal <- syscall.SIGQUIT } // shutdown 停止运行服务器 -func (slf *Server) shutdown(err error) { +func (srv *Server) shutdown(err error) { + if !atomic.CompareAndSwapUint32(&srv.closed, 0, 1) { + return + } if err != nil { log.Error("Server", log.String("state", "shutdown"), log.Err(err)) } - slf.isShutdown.Store(true) - for slf.messageCounter.Load() > 0 { - log.Info("Server", log.Any("network", slf.network), log.String("listen", slf.addr), - log.String("action", "shutdown"), log.String("state", "waiting"), log.Int64("message", slf.messageCounter.Load())) + + for srv.messageCounter.Load() > 0 { + log.Info("Server", log.Any("network", srv.network), log.String("listen", srv.addr), + log.String("action", "shutdown"), log.String("state", "waiting"), log.Int64("message", srv.messageCounter.Load())) time.Sleep(time.Second) } - if slf.multiple == nil { - slf.OnStopEvent() + if srv.multiple == nil { + srv.OnStopEvent() } defer func() { - if slf.multipleRuntimeErrorChan != nil { - slf.multipleRuntimeErrorChan <- err + if srv.multipleRuntimeErrorChan != nil { + srv.multipleRuntimeErrorChan <- err } }() - slf.cancel() - if slf.gServer != nil && slf.isRunning { - if shutdownErr := gnet.Stop(context.Background(), fmt.Sprintf("%s://%s", slf.network, slf.addr)); err != nil { + srv.cancel() + if srv.gServer != nil { + if shutdownErr := gnet.Stop(context.Background(), fmt.Sprintf("%s://%s", srv.network, srv.addr)); err != nil { log.Error("Server", log.Err(shutdownErr)) } } - if slf.tickerPool != nil { - slf.tickerPool.Release() + if srv.tickerPool != nil { + srv.tickerPool.Release() } - if slf.ticker != nil { - slf.ticker.Release() + if srv.ticker != nil { + srv.ticker.Release() } - if slf.ants != nil { - slf.ants.Release() - slf.ants = nil + if srv.ants != nil { + srv.ants.Release() + srv.ants = nil } - slf.dispatcherLock.Lock() - for s, d := range slf.dispatchers { + srv.dispatcherLock.Lock() + for s, d := range srv.dispatchers { + srv.OnShuntChannelClosedEvent(d.name) d.close() - delete(slf.dispatchers, s) + delete(srv.dispatchers, s) } - slf.dispatcherLock.Unlock() - if slf.grpcServer != nil && slf.isRunning { - slf.grpcServer.GracefulStop() + srv.dispatcherLock.Unlock() + if srv.grpcServer != nil { + srv.grpcServer.GracefulStop() } - if slf.httpServer != nil && slf.isRunning { + if srv.httpServer != nil { ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() - if shutdownErr := slf.httpServer.Shutdown(ctx); shutdownErr != nil { + if shutdownErr := srv.httpServer.Shutdown(ctx); shutdownErr != nil { log.Error("Server", log.Err(shutdownErr)) } } if err != nil { - if slf.multiple != nil { - slf.multiple.RegExitEvent(func() { - log.Panic("Server", log.Any("network", slf.network), log.String("listen", slf.addr), + if srv.multiple != nil { + srv.multiple.RegExitEvent(func() { + log.Panic("Server", log.Any("network", srv.network), log.String("listen", srv.addr), log.String("action", "shutdown"), log.String("state", "exception"), log.Err(err)) }) - for i, server := range slf.multiple.servers { - if server.addr == slf.addr { - slf.multiple.servers = append(slf.multiple.servers[:i], slf.multiple.servers[i+1:]...) + for i, server := range srv.multiple.servers { + if server.addr == srv.addr { + srv.multiple.servers = append(srv.multiple.servers[:i], srv.multiple.servers[i+1:]...) break } } } else { - log.Panic("Server", log.Any("network", slf.network), log.String("listen", slf.addr), + log.Panic("Server", log.Any("network", srv.network), log.String("listen", srv.addr), log.String("action", "shutdown"), log.String("state", "exception"), log.Err(err)) } } else { - log.Info("Server", log.Any("network", slf.network), log.String("listen", slf.addr), + log.Info("Server", log.Any("network", srv.network), log.String("listen", srv.addr), log.String("action", "shutdown"), log.String("state", "normal")) } - slf.closeChannel <- struct{}{} + srv.closeChannel <- struct{}{} } // GRPCServer 当网络类型为 NetworkGRPC 时将被允许获取 grpc 服务器,否则将会发生 panic -func (slf *Server) GRPCServer() *grpc.Server { - if slf.grpcServer == nil { +func (srv *Server) GRPCServer() *grpc.Server { + if srv.grpcServer == nil { panic(ErrNetworkOnlySupportGRPC) } - return slf.grpcServer + return srv.grpcServer } // HttpRouter 当网络类型为 NetworkHttp 时将被允许获取路由器进行路由注册,否则将会发生 panic // - 通过该函数注册的路由将无法在服务器关闭时正常等待请求结束 // // Deprecated: 从 Minotaur 0.0.29 开始,由于设计原因已弃用,该函数将直接返回 *gin.Server 对象,导致无法正常的对请求结束时进行处理 -func (slf *Server) HttpRouter() gin.IRouter { - if slf.ginServer == nil { +func (srv *Server) HttpRouter() gin.IRouter { + if srv.ginServer == nil { panic(ErrNetworkOnlySupportHttp) } - return slf.ginServer + return srv.ginServer } // HttpServer 替代 HttpRouter 的函数,返回一个 *Http[*HttpContext] 对象 // - 通过该函数注册的路由将在服务器关闭时正常等待请求结束 // - 如果需要自行包装 Context 对象,可以使用 NewHttpHandleWrapper 方法 -func (slf *Server) HttpServer() *Http[*HttpContext] { - if slf.ginServer == nil { +func (srv *Server) HttpServer() *Http[*HttpContext] { + if srv.ginServer == nil { panic(ErrNetworkOnlySupportHttp) } - return NewHttpHandleWrapper(slf, func(ctx *gin.Context) *HttpContext { + return NewHttpHandleWrapper(srv, func(ctx *gin.Context) *HttpContext { return NewHttpContext(ctx) }) } // GetMessageCount 获取当前服务器中消息的数量 -func (slf *Server) GetMessageCount() int64 { - return slf.messageCounter.Load() +func (srv *Server) GetMessageCount() int64 { + return srv.messageCounter.Load() } // UseShunt 切换连接所使用的消息分流渠道,当分流渠道 name 不存在时将会创建一个新的分流渠道,否则将会加入已存在的分流渠道 // - 默认情况下,所有连接都使用系统通道进行消息分发,当指定消息分流渠道时,将会使用指定的消息分流渠道进行消息分发 // - 在使用 WithDisableAutomaticReleaseShunt 创建服务器后,必须始终在连接不再使用后主动通过 ReleaseShunt 释放消息分流渠道,否则将造成内存泄漏 -func (slf *Server) UseShunt(conn *Conn, name string) { - slf.dispatcherLock.Lock() - defer slf.dispatcherLock.Unlock() - d, exist := slf.dispatchers[name] +func (srv *Server) UseShunt(conn *Conn, name string) { + srv.dispatcherLock.Lock() + defer srv.dispatcherLock.Unlock() + d, exist := srv.dispatchers[name] if !exist { - d = generateDispatcher(slf.dispatcherBufferSize, name, slf.dispatchMessage) + d = generateDispatcher(srv.dispatcherBufferSize, name, srv.dispatchMessage) + srv.OnShuntChannelCreatedEvent(d.name) go d.start() - slf.dispatchers[name] = d + srv.dispatchers[name] = d } - curr, exist := slf.currDispatcher[conn.GetID()] + curr, exist := srv.currDispatcher[conn.GetID()] if exist { if curr.name == name { return } - delete(slf.dispatcherMember[curr.name], conn.GetID()) - if curr.name != serverSystemDispatcher && len(slf.dispatcherMember[curr.name]) == 0 { - delete(slf.dispatchers, curr.name) + delete(srv.dispatcherMember[curr.name], conn.GetID()) + if curr.name != serverSystemDispatcher && len(srv.dispatcherMember[curr.name]) == 0 { + delete(srv.dispatchers, curr.name) curr.transfer(d) + srv.OnShuntChannelClosedEvent(d.name) curr.close() } } - slf.currDispatcher[conn.GetID()] = d + srv.currDispatcher[conn.GetID()] = d - member, exist := slf.dispatcherMember[name] + member, exist := srv.dispatcherMember[name] if !exist { member = map[string]*Conn{} - slf.dispatcherMember[name] = member + srv.dispatcherMember[name] = member } member[conn.GetID()] = conn } // HasShunt 检查特定消息分流渠道是否存在 -func (slf *Server) HasShunt(name string) bool { - slf.dispatcherLock.RLock() - defer slf.dispatcherLock.RUnlock() - _, exist := slf.dispatchers[name] +func (srv *Server) HasShunt(name string) bool { + srv.dispatcherLock.RLock() + defer srv.dispatcherLock.RUnlock() + _, exist := srv.dispatchers[name] return exist } // GetConnCurrShunt 获取连接当前所使用的消息分流渠道 -func (slf *Server) GetConnCurrShunt(conn *Conn) string { - slf.dispatcherLock.RLock() - defer slf.dispatcherLock.RUnlock() - d, exist := slf.currDispatcher[conn.GetID()] +func (srv *Server) GetConnCurrShunt(conn *Conn) string { + srv.dispatcherLock.RLock() + defer srv.dispatcherLock.RUnlock() + d, exist := srv.currDispatcher[conn.GetID()] if exist { return d.name } @@ -610,56 +407,57 @@ func (slf *Server) GetConnCurrShunt(conn *Conn) string { } // GetShuntNum 获取消息分流渠道数量 -func (slf *Server) GetShuntNum() int { - slf.dispatcherLock.RLock() - defer slf.dispatcherLock.RUnlock() - return len(slf.dispatchers) +func (srv *Server) GetShuntNum() int { + srv.dispatcherLock.RLock() + defer srv.dispatcherLock.RUnlock() + return len(srv.dispatchers) } // getConnDispatcher 获取连接所使用的消息分发器 -func (slf *Server) getConnDispatcher(conn *Conn) *dispatcher { +func (srv *Server) getConnDispatcher(conn *Conn) *dispatcher { if conn == nil { - return slf.systemDispatcher + return srv.systemDispatcher } - slf.dispatcherLock.RLock() - defer slf.dispatcherLock.RUnlock() - d, exist := slf.currDispatcher[conn.GetID()] + srv.dispatcherLock.RLock() + defer srv.dispatcherLock.RUnlock() + d, exist := srv.currDispatcher[conn.GetID()] if exist { return d } - return slf.systemDispatcher + return srv.systemDispatcher } // ReleaseShunt 释放分流渠道中的连接,当分流渠道中不再存在连接时将会自动释放分流渠道 // - 在未使用 WithDisableAutomaticReleaseShunt 选项时,当连接关闭时将会自动释放分流渠道中连接的资源占用 // - 若执行过程中连接正在使用,将会切换至系统通道 -func (slf *Server) ReleaseShunt(conn *Conn) { - slf.releaseDispatcher(conn) +func (srv *Server) ReleaseShunt(conn *Conn) { + srv.releaseDispatcher(conn) } // releaseDispatcher 关闭消息分发器 -func (slf *Server) releaseDispatcher(conn *Conn) { +func (srv *Server) releaseDispatcher(conn *Conn) { if conn == nil { return } cid := conn.GetID() - slf.dispatcherLock.Lock() - defer slf.dispatcherLock.Unlock() - d, exist := slf.currDispatcher[cid] + srv.dispatcherLock.Lock() + defer srv.dispatcherLock.Unlock() + d, exist := srv.currDispatcher[cid] if exist && d.name != serverSystemDispatcher { - delete(slf.dispatcherMember[d.name], cid) - if len(slf.dispatcherMember[d.name]) == 0 { + delete(srv.dispatcherMember[d.name], cid) + if len(srv.dispatcherMember[d.name]) == 0 { + srv.OnShuntChannelClosedEvent(d.name) d.close() - delete(slf.dispatchers, d.name) + delete(srv.dispatchers, d.name) } - delete(slf.currDispatcher, cid) + delete(srv.currDispatcher, cid) } } // pushMessage 向服务器中写入特定类型的消息,需严格遵守消息属性要求 -func (slf *Server) pushMessage(message *Message) { - if !slf.OnMessageExecBeforeEvent(message) { - slf.messagePool.Release(message) +func (srv *Server) pushMessage(message *Message) { + if !srv.OnMessageExecBeforeEvent(message) { + srv.messagePool.Release(message) return } var dispatcher *dispatcher @@ -668,22 +466,22 @@ func (slf *Server) pushMessage(message *Message) { MessageTypeShuntTicker, MessageTypeShuntAsync, MessageTypeShuntAsyncCallback, MessageTypeUniqueShuntAsync, MessageTypeUniqueShuntAsyncCallback, MessageTypeShunt: - dispatcher = slf.getConnDispatcher(message.conn) - case MessageTypeSystem, MessageTypeAsync, MessageTypeUniqueAsync, MessageTypeAsyncCallback, MessageTypeUniqueAsyncCallback, MessageTypeError, MessageTypeTicker: - dispatcher = slf.systemDispatcher + dispatcher = srv.getConnDispatcher(message.conn) + case MessageTypeSystem, MessageTypeAsync, MessageTypeUniqueAsync, MessageTypeAsyncCallback, MessageTypeUniqueAsyncCallback, MessageTypeTicker: + dispatcher = srv.systemDispatcher } if dispatcher == nil { return } if (message.t == MessageTypeUniqueShuntAsync || message.t == MessageTypeUniqueAsync) && dispatcher.unique(message.name) { - slf.messagePool.Release(message) + srv.messagePool.Release(message) return } - slf.hitMessageStatistics() + srv.hitMessageStatistics() dispatcher.put(message) } -func (slf *Server) low(message *Message, present time.Time, expect time.Duration, messageReplace ...string) { +func (srv *Server) low(message *Message, present time.Time, expect time.Duration, messageReplace ...string) { cost := time.Since(present) if cost > expect { if len(messageReplace) > 0 { @@ -693,30 +491,30 @@ func (slf *Server) low(message *Message, present time.Time, expect time.Duration } var fields = make([]log.Field, 0, len(message.marks)+5) if message.conn != nil { - fields = append(fields, log.String("shunt", slf.GetConnCurrShunt(message.conn))) + fields = append(fields, log.String("shunt", srv.GetConnCurrShunt(message.conn))) } fields = append(fields, log.String("type", messageNames[message.t]), log.String("cost", cost.String()), log.String("message", message.String())) fields = append(fields, message.marks...) //fields = append(fields, log.Stack("stack")) log.Warn("ServerLowMessage", fields...) - slf.OnMessageLowExecEvent(message, cost) + srv.OnMessageLowExecEvent(message, cost) } } // dispatchMessage 消息分发 -func (slf *Server) dispatchMessage(dispatcher *dispatcher, msg *Message) { +func (srv *Server) dispatchMessage(dispatcher *dispatcher, msg *Message) { var ( ctx context.Context cancel context.CancelFunc ) - if slf.deadlockDetect > 0 { - ctx, cancel = context.WithTimeout(context.Background(), slf.deadlockDetect) + if srv.deadlockDetect > 0 { + ctx, cancel = context.WithTimeout(context.Background(), srv.deadlockDetect) go func(ctx context.Context, msg *Message) { select { case <-ctx.Done(): if err := ctx.Err(); errors.Is(err, context.DeadlineExceeded) { log.Warn("Server", log.String("MessageType", messageNames[msg.t]), log.String("Info", msg.String()), log.Any("SuspectedDeadlock", msg)) - slf.OnDeadlockDetectEvent(msg) + srv.OnDeadlockDetectEvent(msg) } } }(ctx, msg) @@ -730,17 +528,17 @@ func (slf *Server) dispatchMessage(dispatcher *dispatcher, msg *Message) { stack := string(debug.Stack()) log.Error("Server", log.String("MessageType", messageNames[msg.t]), log.String("Info", msg.String()), log.Any("error", err), log.String("stack", stack)) fmt.Println(stack) - slf.OnMessageErrorEvent(msg, err) + srv.OnMessageErrorEvent(msg, err) } if msg.t == MessageTypeUniqueAsyncCallback || msg.t == MessageTypeUniqueShuntAsyncCallback { dispatcher.antiUnique(msg.name) } - slf.low(msg, present, time.Millisecond*100) - slf.messageCounter.Add(-1) + srv.low(msg, present, time.Millisecond*100) + srv.messageCounter.Add(-1) - if !slf.isShutdown.Load() { - slf.messagePool.Release(msg) + if atomic.CompareAndSwapUint32(&srv.closed, 0, 0) { + srv.messagePool.Release(msg) } }(msg) } else { @@ -751,24 +549,15 @@ func (slf *Server) dispatchMessage(dispatcher *dispatcher, msg *Message) { switch msg.t { case MessageTypePacket: - if !slf.OnConnectionPacketPreprocessEvent(msg.conn, msg.packet, func(newPacket []byte) { + if !srv.OnConnectionPacketPreprocessEvent(msg.conn, msg.packet, func(newPacket []byte) { msg.packet = newPacket }) { - slf.OnConnectionReceivePacketEvent(msg.conn, msg.packet) - } - case MessageTypeError: - switch msg.errAction { - case MessageErrorActionNone: - log.Panic("Server", log.Err(msg.err)) - case MessageErrorActionShutdown: - slf.shutdown(msg.err) - default: - log.Warn("Server", log.String("not support message error action", msg.errAction.String())) + srv.OnConnectionReceivePacketEvent(msg.conn, msg.packet) } case MessageTypeTicker, MessageTypeShuntTicker: msg.ordinaryHandler() case MessageTypeAsync, MessageTypeShuntAsync, MessageTypeUniqueAsync, MessageTypeUniqueShuntAsync: - if err := slf.ants.Submit(func() { + if err := srv.ants.Submit(func() { defer func() { if err := super.RecoverTransform(recover()); err != nil { if msg.t == MessageTypeUniqueAsync || msg.t == MessageTypeUniqueShuntAsync { @@ -777,14 +566,14 @@ func (slf *Server) dispatchMessage(dispatcher *dispatcher, msg *Message) { stack := string(debug.Stack()) log.Error("Server", log.String("MessageType", messageNames[msg.t]), log.Any("error", err), log.String("stack", stack)) fmt.Println(stack) - slf.OnMessageErrorEvent(msg, err) + srv.OnMessageErrorEvent(msg, err) } super.Handle(cancel) - slf.low(msg, present, time.Second) - slf.messageCounter.Add(-1) + srv.low(msg, present, time.Second) + srv.messageCounter.Add(-1) - if !slf.isShutdown.Load() { - slf.messagePool.Release(msg) + if atomic.CompareAndSwapUint32(&srv.closed, 0, 0) { + srv.messagePool.Release(msg) } }() var err error @@ -794,17 +583,17 @@ func (slf *Server) dispatchMessage(dispatcher *dispatcher, msg *Message) { if msg.errHandler != nil { if msg.conn == nil { if msg.t == MessageTypeUniqueAsync { - slf.PushUniqueAsyncCallbackMessage(msg.name, err, msg.errHandler) + srv.PushUniqueAsyncCallbackMessage(msg.name, err, msg.errHandler) return } - slf.PushAsyncCallbackMessage(err, msg.errHandler) + srv.PushAsyncCallbackMessage(err, msg.errHandler) return } if msg.t == MessageTypeUniqueShuntAsync { - slf.PushUniqueShuntAsyncCallbackMessage(msg.conn, msg.name, err, msg.errHandler) + srv.PushUniqueShuntAsyncCallbackMessage(msg.conn, msg.name, err, msg.errHandler) return } - slf.PushShuntAsyncCallbackMessage(msg.conn, err, msg.errHandler) + srv.PushShuntAsyncCallbackMessage(msg.conn, err, msg.errHandler) return } dispatcher.antiUnique(msg.name) @@ -826,8 +615,8 @@ func (slf *Server) dispatchMessage(dispatcher *dispatcher, msg *Message) { // PushSystemMessage 向服务器中推送 MessageTypeSystem 消息 // - 系统消息仅包含一个可执行函数,将在系统分发器中执行 // - mark 为可选的日志标记,当发生异常时,将会在日志中进行体现 -func (slf *Server) PushSystemMessage(handler func(), mark ...log.Field) { - slf.pushMessage(slf.messagePool.Get().castToSystemMessage(handler, mark...)) +func (srv *Server) PushSystemMessage(handler func(), mark ...log.Field) { + srv.pushMessage(srv.messagePool.Get().castToSystemMessage(handler, mark...)) } // PushAsyncMessage 向服务器中推送 MessageTypeAsync 消息 @@ -835,34 +624,34 @@ func (slf *Server) PushSystemMessage(handler func(), mark ...log.Field) { // - callback 函数将在异步消息处理完成后进行调用,无论过程是否产生 err,都将被执行,允许为 nil // - 需要注意的是,为了避免并发问题,caller 函数请仅处理阻塞操作,其他操作应该在 callback 函数中进行 // - mark 为可选的日志标记,当发生异常时,将会在日志中进行体现 -func (slf *Server) PushAsyncMessage(caller func() error, callback func(err error), mark ...log.Field) { - slf.pushMessage(slf.messagePool.Get().castToAsyncMessage(caller, callback, mark...)) +func (srv *Server) PushAsyncMessage(caller func() error, callback func(err error), mark ...log.Field) { + srv.pushMessage(srv.messagePool.Get().castToAsyncMessage(caller, callback, mark...)) } // PushAsyncCallbackMessage 向服务器中推送 MessageTypeAsyncCallback 消息 // - 异步消息回调将会通过一个接收 error 的函数进行处理,该函数将在系统分发器中执行 // - mark 为可选的日志标记,当发生异常时,将会在日志中进行体现 -func (slf *Server) PushAsyncCallbackMessage(err error, callback func(err error), mark ...log.Field) { - slf.pushMessage(slf.messagePool.Get().castToAsyncCallbackMessage(err, callback, mark...)) +func (srv *Server) PushAsyncCallbackMessage(err error, callback func(err error), mark ...log.Field) { + srv.pushMessage(srv.messagePool.Get().castToAsyncCallbackMessage(err, callback, mark...)) } // PushShuntAsyncMessage 向特定分发器中推送 MessageTypeAsync 消息,消息执行与 MessageTypeAsync 一致 // - 需要注意的是,当未指定 UseShunt 时,将会通过 PushAsyncMessage 进行转发 // - mark 为可选的日志标记,当发生异常时,将会在日志中进行体现 -func (slf *Server) PushShuntAsyncMessage(conn *Conn, caller func() error, callback func(err error), mark ...log.Field) { - slf.pushMessage(slf.messagePool.Get().castToShuntAsyncMessage(conn, caller, callback, mark...)) +func (srv *Server) PushShuntAsyncMessage(conn *Conn, caller func() error, callback func(err error), mark ...log.Field) { + srv.pushMessage(srv.messagePool.Get().castToShuntAsyncMessage(conn, caller, callback, mark...)) } // PushShuntAsyncCallbackMessage 向特定分发器中推送 MessageTypeAsyncCallback 消息,消息执行与 MessageTypeAsyncCallback 一致 // - 需要注意的是,当未指定 UseShunt 时,将会通过 PushAsyncCallbackMessage 进行转发 -func (slf *Server) PushShuntAsyncCallbackMessage(conn *Conn, err error, callback func(err error), mark ...log.Field) { - slf.pushMessage(slf.messagePool.Get().castToShuntAsyncCallbackMessage(conn, err, callback, mark...)) +func (srv *Server) PushShuntAsyncCallbackMessage(conn *Conn, err error, callback func(err error), mark ...log.Field) { + srv.pushMessage(srv.messagePool.Get().castToShuntAsyncCallbackMessage(conn, err, callback, mark...)) } // PushPacketMessage 向服务器中推送 MessageTypePacket 消息 // - 当存在 UseShunt 的选项时,将会根据选项中的 shuntMatcher 进行分发,否则将在系统分发器中处理消息 -func (slf *Server) PushPacketMessage(conn *Conn, wst int, packet []byte, mark ...log.Field) { - slf.pushMessage(slf.messagePool.Get().castToPacketMessage( +func (srv *Server) PushPacketMessage(conn *Conn, wst int, packet []byte, mark ...log.Field) { + srv.pushMessage(srv.messagePool.Get().castToPacketMessage( &Conn{wst: wst, connection: conn.connection}, packet, mark..., )) @@ -875,62 +664,53 @@ func (slf *Server) PushPacketMessage(conn *Conn, wst int, packet []byte, mark .. // // 定时消息执行不会有特殊的处理,仅标记为定时任务,也就是允许将各类函数通过该消息发送处理,但是并不建议 // - mark 为可选的日志标记,当发生异常时,将会在日志中进行体现 -func (slf *Server) PushTickerMessage(name string, caller func(), mark ...log.Field) { - slf.pushMessage(slf.messagePool.Get().castToTickerMessage(name, caller, mark...)) +func (srv *Server) PushTickerMessage(name string, caller func(), mark ...log.Field) { + srv.pushMessage(srv.messagePool.Get().castToTickerMessage(name, caller, mark...)) } // PushShuntTickerMessage 向特定分发器中推送 MessageTypeTicker 消息,消息执行与 MessageTypeTicker 一致 // - 需要注意的是,当未指定 UseShunt 时,将会通过 PushTickerMessage 进行转发 // - mark 为可选的日志标记,当发生异常时,将会在日志中进行体现 -func (slf *Server) PushShuntTickerMessage(conn *Conn, name string, caller func(), mark ...log.Field) { - slf.pushMessage(slf.messagePool.Get().castToShuntTickerMessage(conn, name, caller, mark...)) +func (srv *Server) PushShuntTickerMessage(conn *Conn, name string, caller func(), mark ...log.Field) { + srv.pushMessage(srv.messagePool.Get().castToShuntTickerMessage(conn, name, caller, mark...)) } // PushUniqueAsyncMessage 向服务器中推送 MessageTypeAsync 消息,消息执行与 MessageTypeAsync 一致 // - 不同的是当上一个相同的 unique 消息未执行完成时,将会忽略该消息 -func (slf *Server) PushUniqueAsyncMessage(unique string, caller func() error, callback func(err error), mark ...log.Field) { - slf.pushMessage(slf.messagePool.Get().castToUniqueAsyncMessage(unique, caller, callback, mark...)) +func (srv *Server) PushUniqueAsyncMessage(unique string, caller func() error, callback func(err error), mark ...log.Field) { + srv.pushMessage(srv.messagePool.Get().castToUniqueAsyncMessage(unique, caller, callback, mark...)) } // PushUniqueAsyncCallbackMessage 向服务器中推送 MessageTypeAsyncCallback 消息,消息执行与 MessageTypeAsyncCallback 一致 -func (slf *Server) PushUniqueAsyncCallbackMessage(unique string, err error, callback func(err error), mark ...log.Field) { - slf.pushMessage(slf.messagePool.Get().castToUniqueAsyncCallbackMessage(unique, err, callback, mark...)) +func (srv *Server) PushUniqueAsyncCallbackMessage(unique string, err error, callback func(err error), mark ...log.Field) { + srv.pushMessage(srv.messagePool.Get().castToUniqueAsyncCallbackMessage(unique, err, callback, mark...)) } // PushUniqueShuntAsyncMessage 向特定分发器中推送 MessageTypeAsync 消息,消息执行与 MessageTypeAsync 一致 // - 需要注意的是,当未指定 UseShunt 时,将会通过系统分流渠道进行转发 // - 不同的是当上一个相同的 unique 消息未执行完成时,将会忽略该消息 -func (slf *Server) PushUniqueShuntAsyncMessage(conn *Conn, unique string, caller func() error, callback func(err error), mark ...log.Field) { - slf.pushMessage(slf.messagePool.Get().castToUniqueShuntAsyncMessage(conn, unique, caller, callback, mark...)) +func (srv *Server) PushUniqueShuntAsyncMessage(conn *Conn, unique string, caller func() error, callback func(err error), mark ...log.Field) { + srv.pushMessage(srv.messagePool.Get().castToUniqueShuntAsyncMessage(conn, unique, caller, callback, mark...)) } // PushUniqueShuntAsyncCallbackMessage 向特定分发器中推送 MessageTypeAsyncCallback 消息,消息执行与 MessageTypeAsyncCallback 一致 // - 需要注意的是,当未指定 UseShunt 时,将会通过系统分流渠道进行转发 -func (slf *Server) PushUniqueShuntAsyncCallbackMessage(conn *Conn, unique string, err error, callback func(err error), mark ...log.Field) { - slf.pushMessage(slf.messagePool.Get().castToUniqueShuntAsyncCallbackMessage(conn, unique, err, callback, mark...)) -} - -// PushErrorMessage 向服务器中推送 MessageTypeError 消息 -// - 通过该函数推送错误消息,当消息触发时将在系统分发器中处理消息 -// - 参数 errAction 用于指定错误消息的处理方式,可选值为 MessageErrorActionNone 和 MessageErrorActionShutdown -// - 参数 errAction 为 MessageErrorActionShutdown 时,将会停止服务器的运行 -// - mark 为可选的日志标记,当发生异常时,将会在日志中进行体现 -func (slf *Server) PushErrorMessage(err error, errAction MessageErrorAction, mark ...log.Field) { - slf.pushMessage(slf.messagePool.Get().castToErrorMessage(err, errAction, mark...)) +func (srv *Server) PushUniqueShuntAsyncCallbackMessage(conn *Conn, unique string, err error, callback func(err error), mark ...log.Field) { + srv.pushMessage(srv.messagePool.Get().castToUniqueShuntAsyncCallbackMessage(conn, unique, err, callback, mark...)) } // PushShuntMessage 向特定分发器中推送 MessageTypeShunt 消息,消息执行与 MessageTypeSystem 一致,不同的是将会在特定分发器中执行 -func (slf *Server) PushShuntMessage(conn *Conn, caller func(), mark ...log.Field) { - slf.pushMessage(slf.messagePool.Get().castToShuntMessage(conn, caller, mark...)) +func (srv *Server) PushShuntMessage(conn *Conn, caller func(), mark ...log.Field) { + srv.pushMessage(srv.messagePool.Get().castToShuntMessage(conn, caller, mark...)) } // startMessageStatistics 开始消息统计 -func (slf *Server) startMessageStatistics() { - if !slf.HasMessageStatistics() { +func (srv *Server) startMessageStatistics() { + if !srv.HasMessageStatistics() { return } - slf.runtime.messageStatistics = append(slf.runtime.messageStatistics, new(atomic.Int64)) - ticker := time.NewTicker(slf.runtime.messageStatisticsDuration) + srv.runtime.messageStatistics = append(srv.runtime.messageStatistics, new(atomic.Int64)) + ticker := time.NewTicker(srv.runtime.messageStatisticsDuration) go func(ctx context.Context, ticker *time.Ticker, r *runtime) { defer ticker.Stop() for { @@ -946,56 +726,95 @@ func (slf *Server) startMessageStatistics() { return } } - }(slf.ctx, ticker, slf.runtime) + }(srv.ctx, ticker, srv.runtime) } // hitMessageStatistics 命中消息统计 -func (slf *Server) hitMessageStatistics() { - slf.messageCounter.Add(1) - if !slf.HasMessageStatistics() { +func (srv *Server) hitMessageStatistics() { + srv.messageCounter.Add(1) + if !srv.HasMessageStatistics() { return } - slf.runtime.messageStatisticsLock.RLock() - slf.runtime.messageStatistics[0].Add(1) - slf.runtime.messageStatisticsLock.RUnlock() + srv.runtime.messageStatisticsLock.RLock() + srv.runtime.messageStatistics[0].Add(1) + srv.runtime.messageStatisticsLock.RUnlock() } // GetDurationMessageCount 获取当前 WithMessageStatistics 设置的 duration 期间的消息量 -func (slf *Server) GetDurationMessageCount() int64 { - return slf.GetDurationMessageCountByOffset(0) +func (srv *Server) GetDurationMessageCount() int64 { + return srv.GetDurationMessageCountByOffset(0) } // GetDurationMessageCountByOffset 获取特定偏移次数的 WithMessageStatistics 设置的 duration 期间的消息量 // - 该值小于 0 时,将与 GetDurationMessageCount 无异,否则将返回 +n 个期间的消息量,例如 duration 为 1 分钟,limit 为 10,那么 offset 为 1 的情况下,获取的则是上一分钟消息量 -func (slf *Server) GetDurationMessageCountByOffset(offset int) int64 { - if !slf.HasMessageStatistics() { +func (srv *Server) GetDurationMessageCountByOffset(offset int) int64 { + if !srv.HasMessageStatistics() { return 0 } - slf.runtime.messageStatisticsLock.Lock() - if offset >= len(slf.runtime.messageStatistics)-1 { - slf.runtime.messageStatisticsLock.Unlock() + srv.runtime.messageStatisticsLock.Lock() + if offset >= len(srv.runtime.messageStatistics)-1 { + srv.runtime.messageStatisticsLock.Unlock() return 0 } - v := slf.runtime.messageStatistics[offset].Load() - slf.runtime.messageStatisticsLock.Unlock() + v := srv.runtime.messageStatistics[offset].Load() + srv.runtime.messageStatisticsLock.Unlock() return v } // GetAllDurationMessageCount 获取所有 WithMessageStatistics 设置的 duration 期间的消息量 -func (slf *Server) GetAllDurationMessageCount() []int64 { - if !slf.HasMessageStatistics() { +func (srv *Server) GetAllDurationMessageCount() []int64 { + if !srv.HasMessageStatistics() { return nil } - slf.runtime.messageStatisticsLock.Lock() - var vs = make([]int64, len(slf.runtime.messageStatistics)) - for i, statistic := range slf.runtime.messageStatistics { + srv.runtime.messageStatisticsLock.Lock() + var vs = make([]int64, len(srv.runtime.messageStatistics)) + for i, statistic := range srv.runtime.messageStatistics { vs[i] = statistic.Load() } - slf.runtime.messageStatisticsLock.Unlock() + srv.runtime.messageStatisticsLock.Unlock() return vs } // HasMessageStatistics 是否了开启消息统计 -func (slf *Server) HasMessageStatistics() bool { - return slf.runtime.messageStatisticsLock != nil +func (srv *Server) HasMessageStatistics() bool { + return srv.runtime.messageStatisticsLock != nil +} + +// ShowServersInfo 显示服务器信息 +func ShowServersInfo(mark string, servers ...*Server) { + var serverInfos = make([]func(), 0, len(servers)) + var ip, _ = network.IP() + for _, srv := range servers { + serverInfos = append(serverInfos, func() { + log.Info("Server", log.String(mark, "RunningInfo"), log.Any("network", srv.network), log.String("ip", ip.String()), log.String("listen", srv.addr)) + }) + } + log.Info("Server", log.String(mark, "====================================================================")) + for _, info := range serverInfos { + info() + } + log.Info("Server", log.String(mark, "====================================================================")) +} + +// onServicesInit 服务初始化 +func onServicesInit(srv *Server) { + for _, service := range srv.services { + service() + } +} + +// onMessageSystemInit 消息系统初始化 +func onMessageSystemInit(srv *Server) { + srv.messagePool = concurrent.NewPool[Message]( + func() *Message { + return &Message{} + }, + func(data *Message) { + data.reset() + }, + ) + srv.startMessageStatistics() + srv.systemDispatcher = generateDispatcher(srv.dispatcherBufferSize, serverSystemDispatcher, srv.dispatchMessage) + go srv.systemDispatcher.start() + srv.OnMessageReadyEvent() } diff --git a/server/server_test.go b/server/server_test.go index 1cc8aef..006ae7a 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -11,6 +11,32 @@ import ( func TestNew(t *testing.T) { srv := server.New(server.NetworkWebsocket, server.WithPProf()) + srv.RegStartBeforeEvent(func(srv *server.Server) { + fmt.Println("启动前") + }) + srv.RegStartFinishEvent(func(srv *server.Server) { + fmt.Println("启动完成") + }) + srv.RegConnectionClosedEvent(func(srv *server.Server, conn *server.Conn, err any) { + fmt.Println("关闭", conn.GetID(), err, "Count", srv.GetOnlineCount()) + }) + + srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet []byte) { + conn.Write(packet) + }) + if err := srv.Run(":9999"); err != nil { + panic(err) + } +} + +func TestNew2(t *testing.T) { + srv := server.New(server.NetworkWebsocket, server.WithPProf()) + srv.RegStartBeforeEvent(func(srv *server.Server) { + fmt.Println("启动前") + }) + srv.RegStartFinishEvent(func(srv *server.Server) { + fmt.Println("启动完成") + }) srv.RegConnectionClosedEvent(func(srv *server.Server, conn *server.Conn, err any) { fmt.Println("关闭", conn.GetID(), err, "Count", srv.GetOnlineCount()) }) diff --git a/server/service.go b/server/service.go index 003d650..f61402e 100644 --- a/server/service.go +++ b/server/service.go @@ -28,10 +28,3 @@ func BindService(srv *Server, services ...Service) { }) } } - -// onServicesInit 服务初始化 -func onServicesInit(srv *Server) { - for _, service := range srv.services { - service() - } -}