diff --git a/server/gnet.go b/server/gnet.go index 7d926a0..0f8069c 100644 --- a/server/gnet.go +++ b/server/gnet.go @@ -21,7 +21,6 @@ func (slf *gNet) OnShutdown(server gnet.Server) { slf.connections.Delete(k) } slf.connections = nil - return } func (slf *gNet) OnOpened(c gnet.Conn) (out []byte, action gnet.Action) { diff --git a/server/server.go b/server/server.go index 2e6737c..9d6e21d 100644 --- a/server/server.go +++ b/server/server.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" @@ -22,12 +23,19 @@ import ( // New 根据特定网络类型创建一个服务器 func New(network Network, options ...Option) *Server { server := &Server{ + event: &event{}, network: network, + options: options, } - server.event = &event{Server: server} + server.event.Server = server if network == NetworkHttp { - server.httpServer = gin.New() + server.ginServer = gin.New() + server.httpServer = &http.Server{ + Handler: server.ginServer, + } + } else if network == NetworkGRPC { + server.grpcServer = grpc.NewServer() } for _, option := range options { option(server) @@ -38,11 +46,14 @@ func New(network Network, options ...Option) *Server { // Server 网络服务器 type Server struct { *event - network Network // 网络类型 - addr string // 侦听地址 + network Network // 网络类型 + addr string // 侦听地址 + options []Option // 选项 + ginServer *gin.Engine // HTTP模式下的路由器 + httpServer *http.Server // HTTP模式下的服务器 + grpcServer *grpc.Server // GRPC模式下的服务器 + connections *synchronization.Map[string, *Conn] // 所有在线的连接 - httpServer *gin.Engine // HTTP模式下的服务器 - grpcServer *grpc.Server // GRPC模式下的服务器 gServer *gNet // TCP或UDP模式下的服务器 messagePool *synchronization.Pool[*message] // 消息池 messagePoolSize int // 消息池大小 @@ -103,7 +114,6 @@ func (slf *Server) Run(addr string) error { if err != nil { return err } - slf.grpcServer = grpc.NewServer() go func() { slf.OnStartBeforeEvent() if err := slf.grpcServer.Serve(listener); err != nil { @@ -159,7 +169,8 @@ func (slf *Server) Run(addr string) error { } go func() { slf.OnStartBeforeEvent() - if err := slf.httpServer.Run(addr); err != nil { + slf.httpServer.Addr = slf.addr + if err := slf.httpServer.ListenAndServe(); err != nil { slf.PushMessage(MessageTypeError, err, MessageErrorActionShutdown) } }() @@ -258,14 +269,32 @@ func (slf *Server) IsDev() bool { // Shutdown 停止运行服务器 func (slf *Server) Shutdown(err error) { - if slf.connections != nil { + if slf.initMessageChannel { + if slf.gServer != nil { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + if shutdownErr := gnet.Stop(ctx, fmt.Sprintf("%s://%s", slf.network, slf.addr)); shutdownErr != nil { + log.Error("Server", zap.Error(shutdownErr)) + } + } slf.connections.Range(func(connId string, conn *Conn) { conn.Close() }) - } - if slf.initMessageChannel { close(slf.messageChannel) + slf.messagePool.Close() + slf.initMessageChannel = false } + if slf.grpcServer != nil { + slf.grpcServer.Stop() + } + if slf.httpServer != nil { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + if shutdownErr := slf.httpServer.Shutdown(ctx); shutdownErr != nil { + log.Error("Server", zap.Error(shutdownErr)) + } + } + if err != nil { log.Error("Server", zap.Any("network", slf.network), zap.String("listen", slf.addr), zap.String("action", "shutdown"), zap.String("state", "exception"), zap.Error(err)) @@ -277,10 +306,10 @@ func (slf *Server) Shutdown(err error) { // HttpRouter 当网络类型为 NetworkHttp 时将被允许获取路由器进行路由注册,否则将会发生 panic func (slf *Server) HttpRouter() gin.IRouter { - if slf.httpServer == nil { + if slf.ginServer == nil { panic(ErrNetworkOnlySupportHttp) } - return slf.httpServer + return slf.ginServer } // PushMessage 向服务器中写入特定类型的消息,需严格遵守消息属性要求 diff --git a/utils/synchronization/pool.go b/utils/synchronization/pool.go index 7bbbf03..21adaf4 100644 --- a/utils/synchronization/pool.go +++ b/utils/synchronization/pool.go @@ -48,6 +48,15 @@ func (slf *Pool[T]) Release(data T) { slf.put(data) } +func (slf *Pool[T]) Close() { + slf.mutex.Lock() + slf.buffers = nil + slf.bufferSize = 0 + slf.generator = nil + slf.releaser = nil + slf.mutex.Unlock() +} + func (slf *Pool[T]) put(data T) { slf.mutex.Lock() if len(slf.buffers) > slf.bufferSize {