diff --git a/go.mod b/go.mod index 14aa3a8..5be7bc6 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/smartystreets/goconvey v1.8.1 github.com/sony/sonyflake v1.2.0 github.com/spf13/cobra v1.7.0 + github.com/stretchr/testify v1.8.3 github.com/tealeg/xlsx v1.0.5 github.com/tidwall/gjson v1.16.0 github.com/xtaci/kcp-go/v5 v5.6.3 @@ -29,6 +30,7 @@ require ( require ( github.com/bytedance/sonic v1.9.1 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-playground/locales v0.14.1 // indirect @@ -48,6 +50,7 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/smarty/assertions v1.15.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/templexxx/cpu v0.1.0 // indirect diff --git a/server/client/client.go b/server/client/client.go index 8a83afa..af03fba 100644 --- a/server/client/client.go +++ b/server/client/client.go @@ -62,7 +62,7 @@ func (slf *Client) Run(block ...bool) error { return err } slf.closed = false - slf.pool = concurrent.NewPool[*Packet](10*1024, func() *Packet { + slf.pool = concurrent.NewPool[Packet](func() *Packet { return new(Packet) }, func(data *Packet) { data.wst = 0 @@ -100,7 +100,6 @@ func (slf *Client) Close(err ...error) { slf.closed = true slf.core.Close() slf.loop.Close() - slf.pool.Close() slf.mutex.Unlock() if len(err) > 0 { slf.OnConnectionClosedEvent(slf, err[0]) diff --git a/server/conn.go b/server/conn.go index cdec322..2ee1c09 100644 --- a/server/conn.go +++ b/server/conn.go @@ -286,7 +286,7 @@ func (slf *Conn) init() { })) } } - slf.pool = concurrent.NewPool[*connPacket](10*1024, + slf.pool = concurrent.NewPool[connPacket]( func() *connPacket { return &connPacket{} }, func(data *connPacket) { @@ -360,7 +360,6 @@ func (slf *Conn) Close(err ...error) { slf.ticker.Release() } slf.server.releaseDispatcher(slf) - slf.pool.Close() slf.loop.Close() slf.mu.Unlock() if len(err) > 0 { diff --git a/server/constants.go b/server/constants.go index 55ff658..99b7a91 100644 --- a/server/constants.go +++ b/server/constants.go @@ -11,7 +11,6 @@ const ( ) const ( - DefaultMessageBufferSize = 1024 DefaultAsyncPoolSize = 256 DefaultWebsocketReadDeadline = 30 * time.Second DefaultPacketWarnSize = 1024 * 1024 * 1 // 1MB diff --git a/server/dispatcher.go b/server/dispatcher.go index 40402d4..ef9c2a0 100644 --- a/server/dispatcher.go +++ b/server/dispatcher.go @@ -11,7 +11,7 @@ var dispatcherUnique = struct{}{} func generateDispatcher(name string, handler func(dispatcher *dispatcher, message *Message)) *dispatcher { return &dispatcher{ name: name, - buffer: buffer.NewUnboundedN[*Message](), + buffer: buffer.NewUnbounded[*Message](), handler: handler, uniques: haxmap.New[string, struct{}](), } diff --git a/server/options.go b/server/options.go index 3df0b0b..0793d12 100644 --- a/server/options.go +++ b/server/options.go @@ -33,7 +33,6 @@ type runtime struct { deadlockDetect time.Duration // 是否开启死锁检测 supportMessageTypes map[int]bool // websocket模式下支持的消息类型 certFile, keyFile string // TLS文件 - messagePoolSize int // 消息池大小 tickerPool *timer.Pool // 定时器池 ticker *timer.Ticker // 定时器 tickerAutonomy bool // 定时器是否独立运行 @@ -211,18 +210,6 @@ func WithWebsocketMessageType(messageTypes ...int) Option { } } -// WithMessageBufferSize 通过特定的消息缓冲池大小运行服务器 -// - 默认大小为 DefaultMessageBufferSize -// - 消息数量超出这个值的时候,消息处理将会造成更大的开销(频繁创建新的结构体),同时服务器将输出警告内容 -func WithMessageBufferSize(size int) Option { - return func(srv *Server) { - if size <= 0 { - size = 1024 - } - srv.messagePoolSize = size - } -} - // WithPProf 通过性能分析工具PProf创建服务器 func WithPProf(pattern ...string) Option { return func(srv *Server) { diff --git a/server/server.go b/server/server.go index 2b0f924..025885b 100644 --- a/server/server.go +++ b/server/server.go @@ -33,8 +33,7 @@ import ( func New(network Network, options ...Option) *Server { server := &Server{ runtime: &runtime{ - messagePoolSize: DefaultMessageBufferSize, - packetWarnSize: DefaultPacketWarnSize, + packetWarnSize: DefaultPacketWarnSize, }, option: &option{}, network: network, @@ -101,7 +100,6 @@ type Server struct { systemSignal chan os.Signal // 系统信号 closeChannel chan struct{} // 关闭信号 multipleRuntimeErrorChan chan error // 多服务器模式下的运行时错误 - messageLock sync.RWMutex // 消息锁 dispatcherLock sync.RWMutex // 消息分发器锁 isShutdown atomic.Bool // 是否已关闭 messageCounter atomic.Int64 // 消息计数器 @@ -134,43 +132,29 @@ func (slf *Server) Run(addr string) error { slf.addr = addr slf.startMessageStatistics() slf.systemDispatcher = generateDispatcher(serverSystemDispatcher, slf.dispatchMessage) - var protoAddr = fmt.Sprintf("%s://%s", slf.network, slf.addr) - var messageInitFinish = make(chan struct{}, 1) - var connectionInitHandle = func(callback func()) { - slf.messageLock.Lock() - slf.messagePool = concurrent.NewPool[*Message](slf.messagePoolSize, - func() *Message { - return &Message{} - }, - func(data *Message) { - data.reset() - }, - ) - slf.messageLock.Unlock() - if slf.network != NetworkHttp && slf.network != NetworkWebsocket && slf.network != NetworkGRPC { - slf.gServer = &gNet{Server: slf} - } - if callback != nil { - go callback() - } - go func() { - messageInitFinish <- struct{}{} - slf.systemDispatcher.start() - }() + slf.messagePool = concurrent.NewPool[Message]( + func() *Message { + return &Message{} + }, + func(data *Message) { + data.reset() + }, + ) + if slf.network != NetworkHttp && slf.network != NetworkWebsocket && slf.network != NetworkGRPC { + slf.gServer = &gNet{Server: slf} } + var protoAddr = fmt.Sprintf("%s://%s", slf.network, slf.addr) + go slf.systemDispatcher.start() switch slf.network { case NetworkNone: - go connectionInitHandle(func() { - slf.isRunning = true - slf.OnStartBeforeEvent() - }) + slf.isRunning = true + slf.OnStartBeforeEvent() case NetworkGRPC: listener, err := net.Listen(string(NetworkTcp), slf.addr) if err != nil { return err } - go connectionInitHandle(nil) go func() { slf.isRunning = true slf.OnStartBeforeEvent() @@ -180,60 +164,56 @@ func (slf *Server) Run(addr string) error { } }() case NetworkTcp, NetworkTcp4, NetworkTcp6, NetworkUdp, NetworkUdp4, NetworkUdp6, NetworkUnix: - go connectionInitHandle(func() { - slf.isRunning = true - slf.OnStartBeforeEvent() - if err := gnet.Serve(slf.gServer, protoAddr, - gnet.WithLogger(new(logger.GNet)), - gnet.WithTicker(true), - gnet.WithMulticore(true), - ); err != nil { - slf.isRunning = false - slf.PushErrorMessage(err, MessageErrorActionShutdown) - } - }) + slf.isRunning = true + slf.OnStartBeforeEvent() + if err := gnet.Serve(slf.gServer, protoAddr, + gnet.WithLogger(new(logger.GNet)), + gnet.WithTicker(true), + gnet.WithMulticore(true), + ); err != nil { + slf.isRunning = false + slf.PushErrorMessage(err, MessageErrorActionShutdown) + } case NetworkKcp: listener, err := kcp.ListenWithOptions(slf.addr, nil, 0, 0) if err != nil { return err } - go connectionInitHandle(func() { - slf.isRunning = true - slf.OnStartBeforeEvent() - for { - session, err := listener.AcceptKCP() - if err != nil { - continue - } - - conn := newKcpConn(slf, session) - slf.OnConnectionOpenedEvent(conn) - - go func(conn *Conn) { - defer func() { - if err := recover(); err != nil { - e, ok := err.(error) - if !ok { - e = fmt.Errorf("%v", err) - } - conn.Close(e) - } - }() - - buf := make([]byte, 4096) - for !conn.IsClosed() { - n, err := conn.kcp.Read(buf) - if err != nil { - if conn.IsClosed() { - break - } - panic(err) - } - slf.PushPacketMessage(conn, 0, buf[:n]) - } - }(conn) + slf.isRunning = true + slf.OnStartBeforeEvent() + for { + session, err := listener.AcceptKCP() + if err != nil { + continue } - }) + + conn := newKcpConn(slf, session) + slf.OnConnectionOpenedEvent(conn) + + go func(conn *Conn) { + defer func() { + if err := recover(); err != nil { + e, ok := err.(error) + if !ok { + e = fmt.Errorf("%v", err) + } + conn.Close(e) + } + }() + + buf := make([]byte, 4096) + for !conn.IsClosed() { + n, err := conn.kcp.Read(buf) + if err != nil { + if conn.IsClosed() { + break + } + panic(err) + } + slf.PushPacketMessage(conn, 0, buf[:n]) + } + }(conn) + } case NetworkHttp: go func() { slf.isRunning = true @@ -248,7 +228,6 @@ func (slf *Server) Run(addr string) error { log.String("ip", c.ClientIP()), log.String("path", c.Request.URL.Path), log.Duration("cost", time.Since(t))) }) - go connectionInitHandle(nil) if len(slf.certFile)+len(slf.keyFile) > 0 { if err := slf.httpServer.ListenAndServeTLS(slf.certFile, slf.keyFile); err != nil { slf.isRunning = false @@ -263,94 +242,92 @@ func (slf *Server) Run(addr string) error { }() case NetworkWebsocket: - go connectionInitHandle(func() { - var pattern string - var index = strings.Index(addr, "/") - if index == -1 { - pattern = "/" - } else { - pattern = addr[index:] - slf.addr = slf.addr[:index] + var pattern string + var index = strings.Index(addr, "/") + if index == -1 { + pattern = "/" + } else { + pattern = addr[index:] + slf.addr = slf.addr[:index] + } + var upgrade = websocket.Upgrader{ + ReadBufferSize: 4096, + WriteBufferSize: 4096, + CheckOrigin: func(r *http.Request) bool { + return true + }, + } + http.HandleFunc(pattern, func(writer http.ResponseWriter, request *http.Request) { + ip := request.Header.Get("X-Real-IP") + ws, err := upgrade.Upgrade(writer, request, nil) + if err != nil { + return } - var upgrade = websocket.Upgrader{ - ReadBufferSize: 4096, - WriteBufferSize: 4096, - CheckOrigin: func(r *http.Request) bool { - return true - }, + if len(ip) == 0 { + addr := ws.RemoteAddr().String() + if index := strings.LastIndex(addr, ":"); index != -1 { + ip = addr[0:index] + } } - http.HandleFunc(pattern, func(writer http.ResponseWriter, request *http.Request) { - ip := request.Header.Get("X-Real-IP") - ws, err := upgrade.Upgrade(writer, request, nil) - if err != nil { - return - } - if len(ip) == 0 { - addr := ws.RemoteAddr().String() - if index := strings.LastIndex(addr, ":"); index != -1 { - ip = addr[0:index] - } - } - if slf.websocketCompression > 0 { - _ = ws.SetCompressionLevel(slf.websocketCompression) - } - ws.EnableWriteCompression(slf.websocketWriteCompression) - conn := newWebsocketConn(slf, ws, ip) - conn.SetData(wsRequestKey, request) - for k, v := range request.URL.Query() { - if len(v) == 1 { - conn.SetData(k, v[0]) - } else { - conn.SetData(k, v) - } - } - slf.OnConnectionOpenedEvent(conn) - - defer func() { - if err := recover(); err != nil { - e, ok := err.(error) - if !ok { - e = fmt.Errorf("%v", err) - } - conn.Close(e) - } - }() - for !conn.IsClosed() { - if slf.websocketReadDeadline > 0 { - if err := ws.SetReadDeadline(time.Now().Add(slf.websocketReadDeadline)); err != nil { - panic(err) - } - } - messageType, packet, readErr := ws.ReadMessage() - if readErr != nil { - if conn.IsClosed() { - break - } - panic(readErr) - } - if len(slf.supportMessageTypes) > 0 && !slf.supportMessageTypes[messageType] { - panic(ErrWebsocketIllegalMessageType) - } - slf.PushPacketMessage(conn, messageType, packet) - } - }) - go func() { - slf.isRunning = true - slf.OnStartBeforeEvent() - if len(slf.certFile)+len(slf.keyFile) > 0 { - if err := http.ListenAndServeTLS(slf.addr, slf.certFile, slf.keyFile, nil); err != nil { - slf.isRunning = false - slf.PushErrorMessage(err, MessageErrorActionShutdown) - } + if slf.websocketCompression > 0 { + _ = ws.SetCompressionLevel(slf.websocketCompression) + } + ws.EnableWriteCompression(slf.websocketWriteCompression) + conn := newWebsocketConn(slf, ws, ip) + conn.SetData(wsRequestKey, request) + for k, v := range request.URL.Query() { + if len(v) == 1 { + conn.SetData(k, v[0]) } else { - if err := http.ListenAndServe(slf.addr, nil); err != nil { - slf.isRunning = false - slf.PushErrorMessage(err, MessageErrorActionShutdown) + conn.SetData(k, v) + } + } + slf.OnConnectionOpenedEvent(conn) + + defer func() { + if err := recover(); err != nil { + e, ok := err.(error) + if !ok { + e = fmt.Errorf("%v", err) + } + conn.Close(e) + } + }() + for !conn.IsClosed() { + if slf.websocketReadDeadline > 0 { + if err := ws.SetReadDeadline(time.Now().Add(slf.websocketReadDeadline)); err != nil { + panic(err) } } - - }() + messageType, packet, readErr := ws.ReadMessage() + if readErr != nil { + if conn.IsClosed() { + break + } + panic(readErr) + } + if len(slf.supportMessageTypes) > 0 && !slf.supportMessageTypes[messageType] { + panic(ErrWebsocketIllegalMessageType) + } + slf.PushPacketMessage(conn, messageType, packet) + } }) + go func() { + slf.isRunning = true + slf.OnStartBeforeEvent() + if len(slf.certFile)+len(slf.keyFile) > 0 { + if err := http.ListenAndServeTLS(slf.addr, slf.certFile, slf.keyFile, nil); err != nil { + slf.isRunning = false + slf.PushErrorMessage(err, MessageErrorActionShutdown) + } + } else { + if err := http.ListenAndServe(slf.addr, nil); err != nil { + slf.isRunning = false + slf.PushErrorMessage(err, MessageErrorActionShutdown) + } + } + + }() default: return ErrCanNotSupportNetwork } @@ -359,9 +336,6 @@ func (slf *Server) Run(addr string) error { kcp.SystemTimedSched.Close() } - <-messageInitFinish - close(messageInitFinish) - messageInitFinish = nil if slf.multiple == nil { ip, _ := network.IP() log.Info("Server", log.String(serverMark, "====================================================================")) @@ -652,7 +626,7 @@ func (slf *Server) releaseDispatcher(conn *Conn) { // pushMessage 向服务器中写入特定类型的消息,需严格遵守消息属性要求 func (slf *Server) pushMessage(message *Message) { - if slf.messagePool.IsClose() || !slf.OnMessageExecBeforeEvent(message) { + if !slf.OnMessageExecBeforeEvent(message) { slf.messagePool.Release(message) return } @@ -863,7 +837,7 @@ func (slf *Server) PushShuntAsyncCallbackMessage(conn *Conn, err error, callback func (slf *Server) PushPacketMessage(conn *Conn, wst int, packet []byte, mark ...log.Field) { slf.pushMessage(slf.messagePool.Get().castToPacketMessage( &Conn{wst: wst, connection: conn.connection}, - packet, + packet, mark..., )) } diff --git a/server/server_test.go b/server/server_test.go index eb998a6..271e693 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -11,7 +11,7 @@ import ( func TestNew(t *testing.T) { //limiter := rate.NewLimiter(rate.Every(time.Second), 100) - srv := server.New(server.NetworkWebsocket, server.WithTicker(-1, 200, 10, false), server.WithMessageBufferSize(1024*1024), server.WithPProf()) + srv := server.New(server.NetworkWebsocket, server.WithTicker(-1, 200, 10, false), server.WithPProf()) //srv.RegMessageExecBeforeEvent(func(srv *server.Server, message *server.Message) bool { // t, c := srv.TimeoutContext(time.Second * 5) // defer c()