diff --git a/server/server.go b/server/server.go index aefdc09..d522ce2 100644 --- a/server/server.go +++ b/server/server.go @@ -8,6 +8,7 @@ import ( "github.com/kercylan98/minotaur/utils/log" "github.com/kercylan98/minotaur/utils/synchronization" "github.com/panjf2000/gnet" + "github.com/pkg/errors" "github.com/xtaci/kcp-go/v5" "go.uber.org/zap" "google.golang.org/grpc" @@ -15,7 +16,9 @@ import ( "net/http" "os" "os/signal" + "runtime/debug" "strings" + "sync/atomic" "syscall" "time" ) @@ -23,10 +26,11 @@ import ( // New 根据特定网络类型创建一个服务器 func New(network Network, options ...Option) *Server { server := &Server{ - event: &event{}, - network: network, - options: options, - core: 1, + event: &event{}, + network: network, + options: options, + core: 1, + closeChannel: make(chan struct{}), } server.event.Server = server @@ -47,14 +51,16 @@ func New(network Network, options ...Option) *Server { // Server 网络服务器 type Server struct { *event - network Network // 网络类型 - addr string // 侦听地址 - options []Option // 选项 - ginServer *gin.Engine // HTTP模式下的路由器 - httpServer *http.Server // HTTP模式下的服务器 - grpcServer *grpc.Server // GRPC模式下的服务器 - supportMessageTypes map[int]bool // websocket模式下支持的消息类型 - certFile, keyFile string // TLS文件 + network Network // 网络类型 + addr string // 侦听地址 + options []Option // 选项 + ginServer *gin.Engine // HTTP模式下的路由器 + httpServer *http.Server // HTTP模式下的服务器 + grpcServer *grpc.Server // GRPC模式下的服务器 + supportMessageTypes map[int]bool // websocket模式下支持的消息类型 + certFile, keyFile string // TLS文件 + isShutdown atomic.Bool // 是否已关闭 + closeChannel chan struct{} // 关闭信号 gServer *gNet // TCP或UDP模式下的服务器 messagePool *synchronization.Pool[*message] // 消息池 @@ -93,7 +99,8 @@ func (slf *Server) Run(addr string) error { slf.messagePool = synchronization.NewPool[*message](slf.messagePoolSize, func() *message { return &message{} - }, func(data *message) { + }, + func(data *message) { data.t = 0 data.attrs = nil }, @@ -123,14 +130,14 @@ func (slf *Server) Run(addr string) error { go func() { slf.OnStartBeforeEvent() if err := slf.grpcServer.Serve(listener); err != nil { - slf.PushMessage(MessageTypeError, err, MessageErrorActionShutdown) + slf.PushMessage(MessageTypeError, errors.WithMessage(err, string(debug.Stack())), MessageErrorActionShutdown) } }() case NetworkTCP, NetworkTCP4, NetworkTCP6, NetworkUdp, NetworkUdp4, NetworkUdp6, NetworkUnix: go connectionInitHandle(func() { slf.OnStartBeforeEvent() if err := gnet.Serve(slf.gServer, protoAddr); err != nil { - slf.PushMessage(MessageTypeError, err, MessageErrorActionShutdown) + slf.PushMessage(MessageTypeError, errors.WithMessage(err, string(debug.Stack())), MessageErrorActionShutdown) } }) case NetworkKcp: @@ -178,88 +185,90 @@ func (slf *Server) Run(addr string) error { slf.httpServer.Addr = slf.addr if len(slf.certFile)+len(slf.keyFile) > 0 { if err := slf.httpServer.ListenAndServeTLS(slf.certFile, slf.keyFile); err != nil { - slf.PushMessage(MessageTypeError, err, MessageErrorActionShutdown) + slf.PushMessage(MessageTypeError, errors.WithMessage(err, string(debug.Stack())), MessageErrorActionShutdown) } } else { if err := slf.httpServer.ListenAndServe(); err != nil { - slf.PushMessage(MessageTypeError, err, MessageErrorActionShutdown) + slf.PushMessage(MessageTypeError, errors.WithMessage(err, string(debug.Stack())), MessageErrorActionShutdown) } } }() case NetworkWebsocket: - go connectionInitHandle(nil) - var pattern string - var index = strings.Index(addr, "/") - if index == -1 { - pattern = "/" - } else { - pattern = addr[index:] - } - var upgrade = websocket.Upgrader{ - ReadBufferSize: 4096, - WriteBufferSize: 4096, - CheckOrigin: func(r *http.Request) bool { - return true - }, - } - http.HandleFunc(pattern, func(writer http.ResponseWriter, request *http.Request) { - ip := request.Header.Get("X-Real-IP") - ws, err := upgrade.Upgrade(writer, request, nil) - if err != nil { - return - } - if len(ip) == 0 { - addr := ws.RemoteAddr().String() - if index := strings.LastIndex(addr, ":"); index != -1 { - ip = addr[0:index] - } - } - - conn := newWebsocketConn(ws, ip) - for k, v := range request.URL.Query() { - if len(v) == 1 { - conn.SetData(k, v) - } else { - conn.SetData(k, v) - } - } - slf.OnConnectionOpenedEvent(conn) - - defer func() { - if err := recover(); err != nil { - conn.Close() - slf.OnConnectionClosedEvent(conn) - } - }() - - for { - if err := ws.SetReadDeadline(time.Now().Add(time.Second * 30)); err != nil { - panic(err) - } - messageType, packet, err := ws.ReadMessage() - if err != nil { - panic(err) - } - if !slf.supportMessageTypes[messageType] { - panic(ErrWebsocketIllegalMessageType) - } - slf.PushMessage(MessageTypePacket, conn, packet, messageType) - } - }) - go func() { - slf.OnStartBeforeEvent() - if len(slf.certFile)+len(slf.keyFile) > 0 { - if err := http.ListenAndServeTLS(slf.addr, slf.certFile, slf.keyFile, nil); err != nil { - slf.PushMessage(MessageTypeError, err, MessageErrorActionShutdown) - } + go connectionInitHandle(func() { + var pattern string + var index = strings.Index(addr, "/") + if index == -1 { + pattern = "/" } else { - if err := http.ListenAndServe(slf.addr, nil); err != nil { - slf.PushMessage(MessageTypeError, err, MessageErrorActionShutdown) - } + pattern = addr[index:] + slf.addr = slf.addr[:index] } + var upgrade = websocket.Upgrader{ + ReadBufferSize: 4096, + WriteBufferSize: 4096, + CheckOrigin: func(r *http.Request) bool { + return true + }, + } + http.HandleFunc(pattern, func(writer http.ResponseWriter, request *http.Request) { + ip := request.Header.Get("X-Real-IP") + ws, err := upgrade.Upgrade(writer, request, nil) + if err != nil { + return + } + if len(ip) == 0 { + addr := ws.RemoteAddr().String() + if index := strings.LastIndex(addr, ":"); index != -1 { + ip = addr[0:index] + } + } - }() + conn := newWebsocketConn(ws, ip) + for k, v := range request.URL.Query() { + if len(v) == 1 { + conn.SetData(k, v) + } else { + conn.SetData(k, v) + } + } + slf.OnConnectionOpenedEvent(conn) + + defer func() { + if err := recover(); err != nil { + conn.Close() + slf.OnConnectionClosedEvent(conn) + } + }() + + for { + if err := ws.SetReadDeadline(time.Now().Add(time.Second * 30)); err != nil { + panic(err) + } + messageType, packet, err := ws.ReadMessage() + if err != nil { + panic(err) + } + if !slf.supportMessageTypes[messageType] { + panic(ErrWebsocketIllegalMessageType) + } + slf.PushMessage(MessageTypePacket, conn, packet, messageType) + } + }) + go func() { + slf.OnStartBeforeEvent() + if len(slf.certFile)+len(slf.keyFile) > 0 { + if err := http.ListenAndServeTLS(slf.addr, slf.certFile, slf.keyFile, nil); err != nil { + slf.PushMessage(MessageTypeError, errors.WithMessage(err, string(debug.Stack())), MessageErrorActionShutdown) + } + } else { + if err := http.ListenAndServe(slf.addr, nil); err != nil { + slf.PushMessage(MessageTypeError, errors.WithMessage(err, string(debug.Stack())), MessageErrorActionShutdown) + } + } + + }() + }) default: return ErrCanNotSupportNetwork } @@ -278,6 +287,9 @@ func (slf *Server) Run(addr string) error { select { case <-systemSignal: slf.Shutdown(nil) + case <-slf.closeChannel: + close(slf.closeChannel) + break } } else { slf.OnStartFinishEvent() @@ -298,6 +310,7 @@ func (slf *Server) IsDev() bool { // Shutdown 停止运行服务器 func (slf *Server) Shutdown(err error) { + slf.isShutdown.Store(true) if slf.initMessageChannel { if slf.gServer != nil { ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) @@ -324,6 +337,7 @@ func (slf *Server) Shutdown(err error) { if err != nil { log.Error("Server", zap.Any("network", slf.network), zap.String("listen", slf.addr), zap.String("action", "shutdown"), zap.String("state", "exception"), zap.Error(err)) + slf.closeChannel <- struct{}{} } else { log.Info("Server", zap.Any("network", slf.network), zap.String("listen", slf.addr), zap.String("action", "shutdown"), zap.String("state", "normal")) @@ -356,7 +370,9 @@ func (slf *Server) PushMessage(messageType MessageType, attrs ...any) { // dispatchMessage 消息分发 func (slf *Server) dispatchMessage(msg *message) { defer func() { - slf.messagePool.Release(msg) + if !slf.isShutdown.Load() { + slf.messagePool.Release(msg) + } if err := recover(); err != nil { log.Error("Server", zap.String("MessageType", messageNames[msg.t]), zap.Any("MessageAttrs", msg.attrs), zap.Any("error", err)) } @@ -377,6 +393,7 @@ func (slf *Server) dispatchMessage(msg *message) { log.Error("Server", zap.Error(err)) case MessageErrorActionShutdown: slf.Shutdown(err) + fmt.Println(err) default: log.Warn("Server", zap.String("not support message error action", action.String())) }