diff --git a/server/conn.go b/server/conn.go index 9d9784d..555a39e 100644 --- a/server/conn.go +++ b/server/conn.go @@ -297,7 +297,7 @@ func (slf *Conn) init() { ) slf.loop = writeloop.NewChannel[*connPacket](slf.pool, slf.server.connWriteBufferSize, func(data *connPacket) error { if slf.server.runtime.packetWarnSize > 0 && len(data.packet) > slf.server.runtime.packetWarnSize { - log.Warn("Conn.Write", log.String("State", "PacketWarn"), log.String("Reason", "PacketSize"), log.String("ID", slf.GetID()), log.Int("PacketSize", len(data.packet))) + log.Warn("Conn.Put", log.String("State", "PacketWarn"), log.String("Reason", "PacketSize"), log.String("ID", slf.GetID()), log.Int("PacketSize", len(data.packet))) } var err error if slf.delay > 0 || slf.fluctuation > 0 { diff --git a/server/v2/actor/actor.go b/server/v2/actor/actor.go new file mode 100644 index 0000000..bb97184 --- /dev/null +++ b/server/v2/actor/actor.go @@ -0,0 +1,153 @@ +package actor + +import ( + "context" + "github.com/kercylan98/minotaur/utils/buffer" + "github.com/kercylan98/minotaur/utils/super" + "sync" + "time" +) + +// MessageHandler 定义了处理消息的函数类型 +type MessageHandler[M any] func(message M) + +// NewActor 创建一个新的 Actor,并启动其消息处理循环 +func NewActor[M any](ctx context.Context, handler MessageHandler[M]) *Actor[M] { + a := newActor(ctx, handler) + a.counter = new(super.Counter[int]) + go a.run() + return a +} + +// newActor 创建一个新的 Actor +func newActor[M any](ctx context.Context, handler MessageHandler[M]) (actor *Actor[M]) { + a := &Actor[M]{ + buf: buffer.NewRing[M](1024), + handler: handler, + } + a.cond = sync.NewCond(&a.rw) + a.ctx, a.cancel = context.WithCancel(ctx) + return a +} + +// Actor 是一个消息驱动的并发实体 +type Actor[M any] struct { + idx int // Actor 在其父 Actor 中的索引 + ctx context.Context // Actor 的上下文 + cancel context.CancelFunc // 用于取消 Actor 的函数 + buf *buffer.Ring[M] // 用于缓存消息的环形缓冲区 + handler MessageHandler[M] // 处理消息的函数 + rw sync.RWMutex // 读写锁,用于保护 Actor 的并发访问 + cond *sync.Cond // 条件变量,用于触发消息处理流程 + counter *super.Counter[int] // 消息计数器,用于统计处理的消息数量 + dying bool // 标识 Actor 是否正在关闭中 + parent *Actor[M] // 父 Actor + subs []*Actor[M] // 子 Actor 切片 + gap []int // 用于记录已经关闭的子 Actor 的索引位置,以便复用 +} + +// run 启动 Actor 的消息处理循环 +func (a *Actor[M]) run() { + var ctx = a.ctx + var clearGap = time.NewTicker(time.Second * 30) + defer func(a *Actor[M], clearGap *time.Ticker) { + clearGap.Stop() + a.cancel() + a.parent.removeSub(a) + }(a, clearGap) + for { + select { + case <-a.ctx.Done(): + a.rw.Lock() + if ctx == a.ctx { + a.dying = true + } else { + ctx = a.ctx + } + a.rw.Unlock() + a.cond.Signal() + case <-clearGap.C: + a.rw.Lock() + for _, idx := range a.gap { + a.subs = append(a.subs[:idx], a.subs[idx+1:]...) + } + for idx, sub := range a.subs { + sub.idx = idx + } + a.gap = a.gap[:0] + a.rw.Unlock() + default: + a.rw.Lock() + if a.buf.IsEmpty() { + if a.dying && a.counter.Val() == 0 { + return + } + a.cond.Wait() + } + messages := a.buf.ReadAll() + a.rw.Unlock() + for _, message := range messages { + a.handler(message) + } + a.counter.Add(-len(messages)) + } + } +} + +// Reuse 重用 Actor,Actor 会重新激活 +func (a *Actor[M]) Reuse(ctx context.Context) { + before := a.cancel + defer before() + + a.rw.Lock() + a.ctx, a.cancel = context.WithCancel(ctx) + a.dying = false + for _, sub := range a.subs { + sub.Reuse(a.ctx) + } + a.rw.Unlock() + a.cond.Signal() +} + +// Send 发送消息 +func (a *Actor[M]) Send(message M) { + a.rw.Lock() + a.counter.Add(1) + a.buf.Write(message) + a.rw.Unlock() + a.cond.Signal() +} + +// Sub 派生一个子 Actor,该子 Actor 生命周期将继承父 Actor 的生命周期 +func (a *Actor[M]) Sub() { + a.rw.Lock() + defer a.rw.Unlock() + + sub := newActor(a.ctx, a.handler) + sub.counter = a.counter.Sub() + sub.parent = a + if len(a.gap) > 0 { + sub.idx = a.gap[0] + a.gap = a.gap[1:] + } else { + sub.idx = len(a.subs) + } + a.subs = append(a.subs, sub) + go sub.run() +} + +// removeSub 从父 Actor 中移除指定的子 Actor +func (a *Actor[M]) removeSub(sub *Actor[M]) { + if a == nil { + return + } + + a.rw.Lock() + defer a.rw.Unlock() + if sub.idx == len(a.subs)-1 { + a.subs = a.subs[:sub.idx] + return + } + a.subs[sub.idx] = nil + a.gap = append(a.gap, sub.idx) +} diff --git a/server/v2/conn.go b/server/v2/conn.go index f908e65..bfbb0f6 100644 --- a/server/v2/conn.go +++ b/server/v2/conn.go @@ -2,58 +2,23 @@ package server import ( "context" - "github.com/kercylan98/minotaur/utils/log" + "github.com/kercylan98/minotaur/server/v2/actor" "net" - "unsafe" ) type Conn interface { - net.Conn } -type conn struct { - net.Conn - cs *connections - ctx context.Context - cancel context.CancelFunc - idx int -} - -func (c *conn) init(ctx context.Context, cs *connections, conn net.Conn, idx int) *conn { - c.Conn = conn - c.cs = cs - c.ctx, c.cancel = context.WithCancel(ctx) - c.idx = idx - return c -} - -func (c *conn) awaitRead() { - defer func() { _ = c.Close() }() - - const bufferSize = 4096 - buf := make([]byte, bufferSize) // 避免频繁的内存分配,初始化一个固定大小的缓冲区 - for { - select { - case <-c.ctx.Done(): - return - default: - ptr := unsafe.Pointer(&buf[0]) - n, err := c.Read((*[bufferSize]byte)(ptr)[:]) - if err != nil { - log.Error("READ", err) - return - } - - if n > 0 { - if _, err := c.Write(buf[:n]); err != nil { - log.Error("Write", err) - } - } - } +func newConn(ctx context.Context, c net.Conn, connWriter ConnWriter, handler actor.MessageHandler[Packet]) Conn { + return &conn{ + conn: c, + writer: connWriter, + actor: actor.NewActor[Packet](ctx, handler), } } -func (c *conn) Close() (err error) { - c.cs.Event() <- c - return +type conn struct { + conn net.Conn + writer ConnWriter + actor *actor.Actor[Packet] } diff --git a/server/v2/connections.go b/server/v2/connections.go deleted file mode 100644 index 9681aa7..0000000 --- a/server/v2/connections.go +++ /dev/null @@ -1,111 +0,0 @@ -package server - -import ( - "context" - "github.com/kercylan98/minotaur/utils/log" - "net" - "time" -) - -// connections 结构体用于管理连接 -type connections struct { - ctx context.Context // 上下文对象,用于取消连接管理器 - ch chan any // 事件通道,用于接收连接管理器的操作事件 - items []*conn // 连接列表,存储所有打开的连接 - gap []int // 连接空隙,记录已关闭的连接索引,用于重用索引 -} - -// 初始化连接管理器 -func (cs *connections) init(ctx context.Context) *connections { - cs.ctx = ctx - cs.ch = make(chan any, 1024) - cs.items = make([]*conn, 0, 128) - go cs.awaitRun() - return cs -} - -// 清理连接列表中的空隙 -func (cs *connections) clearGap() { - cs.gap = cs.gap[:0] - var gap = make([]int, 0, len(cs.items)) - for i, c := range cs.items { - if c == nil { - continue - } - c.idx = i - gap = append(gap, i) - } - - cs.gap = gap -} - -// 打开新连接 -func (cs *connections) open(c net.Conn) error { - // 如果存在连接空隙,则重用连接空隙中的索引,否则分配新的索引 - var idx int - var reuse bool - if len(cs.gap) > 0 { - idx = cs.gap[0] - cs.gap = cs.gap[1:] - reuse = true - } else { - idx = len(cs.items) - } - conn := new(conn).init(cs.ctx, cs, c, idx) - if reuse { - cs.items[idx] = conn - } else { - cs.items = append(cs.items, conn) - } - go conn.awaitRead() - return nil -} - -// 关闭连接 -func (cs *connections) close(c *conn) error { - if c == nil { - return nil - } - defer c.cancel() - // 如果连接索引是连接列表的最后一个索引,则直接删除连接对象,否则将连接对象置空,并将索引添加到连接空隙中 - if c.idx == len(cs.items)-1 { - cs.items = cs.items[:c.idx] - } else { - cs.items[c.idx] = nil - cs.gap = append(cs.gap, c.idx) - } - return c.Conn.Close() -} - -// 等待连接管理器的事件并处理 -func (cs *connections) awaitRun() { - clearGapTicker := time.NewTicker(time.Second * 30) - defer clearGapTicker.Stop() - - for { - select { - case <-cs.ctx.Done(): - return - case <-clearGapTicker.C: - cs.clearGap() - case a := <-cs.ch: - var err error - - switch v := a.(type) { - case *conn: - err = cs.close(v) - case net.Conn: - err = cs.open(v) - } - - if err != nil { - log.Error("connections.awaitRun", log.Any("err", err)) - } - } - } -} - -// Event 获取连接管理器的事件通道 -func (cs *connections) Event() chan<- any { - return cs.ch -} diff --git a/server/v2/core.go b/server/v2/core.go deleted file mode 100644 index b1a3272..0000000 --- a/server/v2/core.go +++ /dev/null @@ -1,9 +0,0 @@ -package server - -type Core interface { - connectionManager -} - -type connectionManager interface { - Event() chan<- any -} diff --git a/server/v2/message.go b/server/v2/message.go new file mode 100644 index 0000000..c9487cd --- /dev/null +++ b/server/v2/message.go @@ -0,0 +1,4 @@ +package server + +type message struct { +} diff --git a/server/v2/network.go b/server/v2/network.go index 0b2f287..05ffee9 100644 --- a/server/v2/network.go +++ b/server/v2/network.go @@ -1,11 +1,13 @@ package server -import "context" +import ( + "context" +) type Network interface { - OnSetup(ctx context.Context, core Core) error + OnSetup(ctx context.Context, event NetworkCore) error - OnRun(ctx context.Context) error + OnRun() error OnShutdown() error } diff --git a/server/v2/network/http.go b/server/v2/network/http.go index 36baac3..0faeb33 100644 --- a/server/v2/network/http.go +++ b/server/v2/network/http.go @@ -30,16 +30,18 @@ type httpCore[H http.Handler] struct { addr string handler H srv *http.Server + event server.NetworkCore } -func (h *httpCore[H]) OnSetup(ctx context.Context, core server.Core) (err error) { +func (h *httpCore[H]) OnSetup(ctx context.Context, event server.NetworkCore) (err error) { + h.event = event h.srv.BaseContext = func(listener net.Listener) context.Context { return ctx } return } -func (h *httpCore[H]) OnRun(ctx context.Context) (err error) { +func (h *httpCore[H]) OnRun() (err error) { if err = h.srv.ListenAndServe(); errors.Is(err, http.ErrServerClosed) { err = nil } diff --git a/server/v2/network/websocket.go b/server/v2/network/websocket.go index aa97dea..39b2f03 100644 --- a/server/v2/network/websocket.go +++ b/server/v2/network/websocket.go @@ -3,55 +3,45 @@ package network import ( "context" "fmt" - "github.com/gobwas/ws" "github.com/kercylan98/minotaur/server/v2" - "net/http" + "github.com/kercylan98/minotaur/utils/collection" + "github.com/panjf2000/gnet/v2" + "time" ) -func WebSocket(addr, pattern string) server.Network { - return WebSocketWithHandler[*HttpServe](addr, &HttpServe{ServeMux: http.NewServeMux()}, func(handler *HttpServe, ws http.HandlerFunc) { - handler.Handle(fmt.Sprintf("GET %s", pattern), ws) - }) -} - -func WebSocketWithHandler[H http.Handler](addr string, handler H, upgraderHandlerFunc WebSocketUpgraderHandlerFunc[H]) server.Network { - c := &websocketCore[H]{ - httpCore: HttpWithHandler(addr, handler).(*httpCore[H]), - upgraderHandlerFunc: upgraderHandlerFunc, +func WebSocket(addr string, pattern ...string) server.Network { + ws := &websocketCore{ + addr: addr, + pattern: collection.FindFirstOrDefaultInSlice(pattern, "/"), } - return c + return ws } -type WebSocketUpgraderHandlerFunc[H http.Handler] func(handler H, ws http.HandlerFunc) - -type websocketCore[H http.Handler] struct { - *httpCore[H] - upgraderHandlerFunc WebSocketUpgraderHandlerFunc[H] - core server.Core +type websocketCore struct { + ctx context.Context + core server.NetworkCore + handler *websocketHandler + addr string + pattern string } -func (w *websocketCore[H]) OnSetup(ctx context.Context, core server.Core) (err error) { +func (w *websocketCore) OnSetup(ctx context.Context, core server.NetworkCore) (err error) { + w.ctx = ctx + w.handler = newWebsocketHandler(w) w.core = core - if err = w.httpCore.OnSetup(ctx, core); err != nil { - return - } - w.upgraderHandlerFunc(w.handler, w.onUpgrade) return } -func (w *websocketCore[H]) OnRun(ctx context.Context) error { - return w.httpCore.OnRun(ctx) +func (w *websocketCore) OnRun() (err error) { + err = gnet.Run(w.handler, fmt.Sprintf("tcp://%s", w.addr)) + return } -func (w *websocketCore[H]) OnShutdown() error { - return w.httpCore.OnShutdown() -} - -func (w *websocketCore[H]) onUpgrade(writer http.ResponseWriter, request *http.Request) { - conn, _, _, err := ws.UpgradeHTTP(request, writer) - if err != nil { - return +func (w *websocketCore) OnShutdown() error { + if w.handler.engine != nil { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + return w.handler.engine.Stop(ctx) } - - w.core.Event() <- conn + return nil } diff --git a/server/v2/network/websocket_handler.go b/server/v2/network/websocket_handler.go new file mode 100644 index 0000000..63efdea --- /dev/null +++ b/server/v2/network/websocket_handler.go @@ -0,0 +1,94 @@ +package network + +import ( + "errors" + "github.com/gobwas/ws" + "github.com/gobwas/ws/wsutil" + "github.com/kercylan98/minotaur/server/v2" + "github.com/kercylan98/minotaur/utils/log" + "github.com/panjf2000/gnet/v2" + "time" +) + +func newWebsocketHandler(core *websocketCore) *websocketHandler { + return &websocketHandler{ + core: core, + } +} + +type websocketHandler struct { + engine *gnet.Engine + upgrader ws.Upgrader + core *websocketCore +} + +func (w *websocketHandler) OnBoot(eng gnet.Engine) (action gnet.Action) { + w.engine = &eng + w.initUpgrader() + return +} + +func (w *websocketHandler) OnShutdown(eng gnet.Engine) { + +} + +func (w *websocketHandler) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) { + wrapper := newWebsocketWrapper(w.core.ctx, c) + c.SetContext(wrapper) + w.core.core.OnConnectionOpened(wrapper.ctx, c, func(message server.Packet) error { + return wsutil.WriteServerMessage(c, message.GetContext().(ws.OpCode), message.GetBytes()) + }) + return +} + +func (w *websocketHandler) OnClose(c gnet.Conn, err error) (action gnet.Action) { + wrapper := c.Context().(*websocketWrapper) + wrapper.cancel() + return +} + +func (w *websocketHandler) OnTraffic(c gnet.Conn) (action gnet.Action) { + wrapper := c.Context().(*websocketWrapper) + + // read to buffer + if err := wrapper.readToBuffer(); err != nil { + log.Error("websocket", log.Err(err)) + return gnet.Close + } + + // check or upgrade + if err := wrapper.upgrade(w.upgrader); err != nil { + log.Error("websocket", log.Err(err)) + return gnet.Close + } + wrapper.active = time.Now() + + // decode + messages, err := wrapper.decode() + if err != nil { + log.Error("websocket", log.Err(err)) + } + + for _, message := range messages { + packet := w.core.core.GeneratePacket(message.Payload) + packet.SetContext(message.OpCode) + w.core.core.OnReceivePacket(packet) + } + + return +} + +func (w *websocketHandler) OnTick() (delay time.Duration, action gnet.Action) { + return +} + +func (w *websocketHandler) initUpgrader() { + w.upgrader = ws.Upgrader{ + OnRequest: func(uri []byte) (err error) { + if string(uri) != w.core.pattern { + err = errors.New("bad request") + } + return + }, + } +} diff --git a/server/v2/network/wsbsocket_wrapper.go b/server/v2/network/wsbsocket_wrapper.go new file mode 100644 index 0000000..f77862c --- /dev/null +++ b/server/v2/network/wsbsocket_wrapper.go @@ -0,0 +1,162 @@ +package network + +import ( + "bytes" + "errors" + "fmt" + "github.com/gobwas/ws" + "github.com/gobwas/ws/wsutil" + "github.com/kercylan98/minotaur/utils/super" + "github.com/panjf2000/gnet/v2" + "golang.org/x/net/context" + "io" + "time" +) + +// newWebsocketWrapper 创建 websocket 包装器 +func newWebsocketWrapper(ctx context.Context, conn gnet.Conn) *websocketWrapper { + wrapper := &websocketWrapper{ + conn: conn, + upgraded: false, + active: time.Now(), + } + wrapper.ctx, wrapper.cancel = context.WithCancel(ctx) + return wrapper +} + +// websocketWrapper websocket 包装器 +type websocketWrapper struct { + ctx context.Context + cancel context.CancelFunc + conn gnet.Conn // 连接 + upgraded bool // 是否已经升级 + hs ws.Handshake // 握手信息 + active time.Time // 活跃时间 + buf bytes.Buffer // 缓冲区 + + header *ws.Header // 当前头部 + cache bytes.Buffer // 缓存的数据 +} + +// readToBuffer 将数据读取到缓冲区 +func (w *websocketWrapper) readToBuffer() error { + size := w.conn.InboundBuffered() + buf := make([]byte, size) + n, err := w.conn.Read(buf) + if err != nil { + return err + } + if n < size { + return fmt.Errorf("incomplete data, read buffer bytes failed! size: %d, read: %d", size, n) + } + w.buf.Write(buf) + return nil +} + +// upgrade 升级 +func (w *websocketWrapper) upgrade(upgrader ws.Upgrader) (err error) { + if w.upgraded { + return + } + + buf := &w.buf + reader := bytes.NewReader(buf.Bytes()) + n := reader.Len() + + w.hs, err = upgrader.Upgrade(super.ReadWriter{ + Reader: reader, + Writer: w.conn, + }) + skip := n - reader.Len() + if err != nil { + if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) { //数据不完整,不跳过 buf 中的 skipN 字节(此时 buf 中存放的仅是部分 "handshake data" bytes),下次再尝试读取 + return + } + buf.Next(skip) + return err + } + buf.Next(skip) + w.upgraded = true + return +} + +// decode 解码 +func (w *websocketWrapper) decode() (messages []wsutil.Message, err error) { + if messages, err = w.read(); err != nil { + return + } + var result = make([]wsutil.Message, 0, len(messages)) + for _, message := range messages { + if message.OpCode.IsControl() { + err = wsutil.HandleClientControlMessage(w.conn, message) + if err != nil { + return + } + continue + } + if message.OpCode == ws.OpText || message.OpCode == ws.OpBinary { + result = append(result, message) + } + } + return result, nil +} + +// decode 解码 +func (w *websocketWrapper) read() (messages []wsutil.Message, err error) { + var buf = &w.buf + for { + // 从缓冲区中读取 header 信息并写入到缓存中 + if w.header == nil { + if buf.Len() < ws.MinHeaderSize { + return // 不完整的数据,不做处理 + } + var header ws.Header + if buf.Len() >= ws.MaxHeaderSize { + header, err = ws.ReadHeader(buf) + if err != nil { + return + } + } else { + // 使用新的 reader 尝试读取 header,避免 header 不完整 + reader := bytes.NewReader(buf.Bytes()) + currLen := reader.Len() + header, err = ws.ReadHeader(reader) + skip := currLen - reader.Len() + if err != nil { + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { + return messages, nil + } + buf.Next(skip) + return nil, err + } + buf.Next(skip) + } + + w.header = &header + if err = ws.WriteHeader(&w.cache, header); err != nil { + return nil, err + } + } + + // 将缓冲区数据读出并写入缓存 + if n := int(w.header.Length); n > 0 { + if buf.Len() < n { + return // 不完整的数据,不做处理 + } + + if _, err = io.CopyN(&w.cache, buf, int64(n)); err != nil { + return + } + } + + // 消息已完整,读取数据帧,否则数据将被分割为多个数据帧 + if w.header.Fin { + messages, err = wsutil.ReadClientMessage(&w.cache, messages) + if err != nil { + return + } + w.cache.Reset() + } + w.header = nil + } +} diff --git a/server/v2/network_core.go b/server/v2/network_core.go new file mode 100644 index 0000000..46f398f --- /dev/null +++ b/server/v2/network_core.go @@ -0,0 +1,50 @@ +package server + +import ( + "github.com/kercylan98/minotaur/utils/hub" + "golang.org/x/net/context" + "net" +) + +type ConnWriter func(message Packet) error + +type NetworkCore interface { + OnConnectionOpened(ctx context.Context, conn net.Conn, writer ConnWriter) + + OnConnectionClosed(conn Conn) + + OnReceivePacket(packet Packet) + + GeneratePacket(data []byte) Packet +} + +type networkCore struct { + *server + packetPool *hub.ObjectPool[*packet] +} + +func (ne *networkCore) init(srv *server) *networkCore { + ne.server = srv + ne.packetPool = hub.NewObjectPool(func() *packet { + return new(packet) + }, func(data *packet) { + data.reset() + }) + return ne +} + +func (ne *networkCore) OnConnectionOpened(ctx context.Context, conn net.Conn, writer ConnWriter) { + +} + +func (ne *networkCore) OnConnectionClosed(conn Conn) { + +} + +func (ne *networkCore) OnReceivePacket(packet Packet) { + +} + +func (ne *networkCore) GeneratePacket(data []byte) Packet { + return ne.packetPool.Get().init(data) +} diff --git a/server/v2/packet.go b/server/v2/packet.go new file mode 100644 index 0000000..ab9d03b --- /dev/null +++ b/server/v2/packet.go @@ -0,0 +1,34 @@ +package server + +type Packet interface { + GetBytes() []byte + SetContext(ctx any) + GetContext() any +} + +type packet struct { + ctx any + data []byte +} + +func (m *packet) init(data []byte) Packet { + m.data = data + return m +} + +func (m *packet) reset() { + m.ctx = nil + m.data = m.data[:0] +} + +func (m *packet) GetBytes() []byte { + return m.data +} + +func (m *packet) SetContext(ctx any) { + m.ctx = ctx +} + +func (m *packet) GetContext() any { + return m.ctx +} diff --git a/server/v2/server.go b/server/v2/server.go index e8fa6a4..17852e4 100644 --- a/server/v2/server.go +++ b/server/v2/server.go @@ -3,7 +3,6 @@ package server import ( "context" "github.com/kercylan98/minotaur/utils/super" - "sync" ) type Server interface { @@ -12,48 +11,33 @@ type Server interface { } type server struct { - ctx *super.CancelContext - networks []Network - connections *connections + *networkCore + ctx *super.CancelContext + network Network } -func NewServer(networks ...Network) Server { +func NewServer(network Network) Server { srv := &server{ - ctx: super.WithCancelContext(context.Background()), - networks: networks, + ctx: super.WithCancelContext(context.Background()), + network: network, } - srv.connections = new(connections).init(srv.ctx) + srv.networkCore = new(networkCore).init(srv) return srv } func (s *server) Run() (err error) { - for _, network := range s.networks { - if err = network.OnSetup(s.ctx, s.connections); err != nil { - return - } + if err = s.network.OnSetup(s.ctx, s); err != nil { + return } - var group = new(sync.WaitGroup) - for _, network := range s.networks { - group.Add(1) - go func(ctx *super.CancelContext, group *sync.WaitGroup, network Network) { - defer group.Done() - if err = network.OnRun(ctx); err != nil { - panic(err) - } - }(s.ctx, group, network) + if err = s.network.OnRun(s.ctx); err != nil { + panic(err) } - group.Wait() - return } func (s *server) Shutdown() (err error) { defer s.ctx.Cancel() - for _, network := range s.networks { - if err = network.OnShutdown(); err != nil { - return - } - } + err = s.network.OnShutdown() return } diff --git a/server/v2/server_test.go b/server/v2/server_test.go index c2df0d5..0130dd6 100644 --- a/server/v2/server_test.go +++ b/server/v2/server_test.go @@ -1,25 +1,13 @@ package server_test import ( - "github.com/gin-gonic/gin" "github.com/kercylan98/minotaur/server/v2" "github.com/kercylan98/minotaur/server/v2/network" - "net/http" "testing" ) func TestNewServer(t *testing.T) { - r := gin.Default() - r.GET("/", func(context *gin.Context) { - context.JSON(200, gin.H{ - "ping": "pong", - }) - }) - srv := server.NewServer(network.WebSocketWithHandler(":9999", r, func(handler *gin.Engine, ws http.HandlerFunc) { - handler.GET("/ws", func(context *gin.Context) { - ws(context.Writer, context.Request) - }) - })) + srv := server.NewServer(network.WebSocket(":9999")) if err := srv.Run(); err != nil { panic(err) } diff --git a/utils/super/counter.go b/utils/super/counter.go new file mode 100644 index 0000000..2585945 --- /dev/null +++ b/utils/super/counter.go @@ -0,0 +1,33 @@ +package super + +import ( + "github.com/kercylan98/minotaur/utils/generic" + "sync" +) + +type Counter[T generic.Number] struct { + v T + p *Counter[T] + rw sync.RWMutex +} + +func (c *Counter[T]) Sub() *Counter[T] { + return &Counter[T]{ + p: c, + } +} + +func (c *Counter[T]) Add(delta T) { + c.rw.Lock() + c.v += delta + c.rw.RUnlock() + if c.p != nil { + c.p.Add(delta) + } +} + +func (c *Counter[T]) Val() T { + c.rw.RLock() + defer c.rw.RUnlock() + return c.v +} diff --git a/utils/super/function.go b/utils/super/function.go index d5fe698..651943b 100644 --- a/utils/super/function.go +++ b/utils/super/function.go @@ -21,3 +21,15 @@ func HandleV[V any](v V, f func(v V)) { f(v) } } + +// SafeExec 安全的执行函数,当 f 为 nil 时,不执行,当 f 执行出错时,不会 panic,而是转化为 error 进行返回 +func SafeExec(f func()) (err error) { + if f == nil { + return + } + defer func() { + err = RecoverTransform(recover()) + }() + f() + return +} diff --git a/utils/super/read_writer.go b/utils/super/read_writer.go new file mode 100644 index 0000000..78601d5 --- /dev/null +++ b/utils/super/read_writer.go @@ -0,0 +1,8 @@ +package super + +import "io" + +type ReadWriter struct { + io.Reader + io.Writer +} diff --git a/utils/unbounded/channel_backlog.go b/utils/unbounded/channel_backlog.go new file mode 100644 index 0000000..fd6552d --- /dev/null +++ b/utils/unbounded/channel_backlog.go @@ -0,0 +1,78 @@ +package unbounded + +import ( + "sync" +) + +// NewChannelBacklog 创建一个并发安全的,基于 channel 和缓冲队列实现的无界缓冲区 +// +// 该缓冲区来源于 gRPC 的实现,用于在不使用额外 goroutine 的情况下实现无界缓冲区 +// - 该缓冲区的所有方法都是线程安全的,除了用于同步的互斥锁外,不会阻塞任何东西 +func NewChannelBacklog[V any]() *ChannelBacklog[V] { + return &ChannelBacklog[V]{c: make(chan V, 1)} +} + +// ChannelBacklog 是并发安全的无界缓冲区的实现 +type ChannelBacklog[V any] struct { + c chan V + closed bool + mu sync.Mutex + backlog []V +} + +// Put 将数据放入缓冲区 +func (cb *ChannelBacklog[V]) Put(t V) { + cb.mu.Lock() + defer cb.mu.Unlock() + if cb.closed { + return + } + if len(cb.backlog) == 0 { + select { + case cb.c <- t: + return + default: + } + } + cb.backlog = append(cb.backlog, t) +} + +// Load 将缓冲区中的数据发送到读取通道中,如果缓冲区中没有数据,则不会发送 +// - 在每次 Get 后都应该执行该函数 +func (cb *ChannelBacklog[V]) Load() { + cb.mu.Lock() + defer cb.mu.Unlock() + if cb.closed { + return + } + if len(cb.backlog) > 0 { + select { + case cb.c <- cb.backlog[0]: + cb.backlog = cb.backlog[1:] + default: + } + } +} + +// Get 获取可接收消息的读取通道,需要注意的是,每次读取成功都应该通过 Load 函数将缓冲区中的数据加载到读取通道中 +func (cb *ChannelBacklog[V]) Get() <-chan V { + return cb.c +} + +// Close 关闭 +func (cb *ChannelBacklog[V]) Close() { + cb.mu.Lock() + defer cb.mu.Unlock() + if cb.closed { + return + } + cb.closed = true + close(cb.c) +} + +// IsClosed 是否已关闭 +func (cb *ChannelBacklog[V]) IsClosed() bool { + cb.mu.Lock() + defer cb.mu.Unlock() + return cb.closed +} diff --git a/utils/unbounded/ring.go b/utils/unbounded/ring.go new file mode 100644 index 0000000..1aa75d9 --- /dev/null +++ b/utils/unbounded/ring.go @@ -0,0 +1,94 @@ +package unbounded + +import ( + "github.com/kercylan98/minotaur/utils/buffer" + "github.com/pkg/errors" + "golang.org/x/net/context" + "sync" +) + +// NewRing 创建一个并发安全的基于环形缓冲区实现的无界缓冲区 +func NewRing[T any](ctx context.Context) *Ring[T] { + r := &Ring[T]{ + ctx: ctx, + ring: buffer.NewRing[T](1024), + ch: make(chan T, 1024), + } + r.cond = sync.NewCond(&r.rw) + + go r.run() + return r +} + +// Ring 是并发安全的,基于环形缓冲区实现的无界缓冲区 +type Ring[T any] struct { + ctx context.Context + ring *buffer.Ring[T] + rw sync.RWMutex + cond *sync.Cond + ch chan T + closed bool +} + +// Put 将数据放入缓冲区 +func (r *Ring[T]) Put(v ...T) error { + if len(v) == 0 { + return nil + } + r.rw.Lock() + if r.closed { + r.rw.Unlock() + return errors.New("unbounded ring is closed") + } + for _, t := range v { + r.ring.Write(t) + } + r.rw.Unlock() + r.cond.Signal() + return nil +} + +// Get 获取可接收消息的读取通道 +func (r *Ring[T]) Get() <-chan T { + return r.ch +} + +// Close 关闭缓冲区 +func (r *Ring[T]) Close() { + r.rw.RLock() + r.closed = true + r.rw.RUnlock() + r.cond.Signal() +} + +// IsClosed 是否已关闭 +func (r *Ring[T]) IsClosed() bool { + r.rw.RLock() + defer r.rw.RUnlock() + return r.closed +} + +func (r *Ring[T]) run() { + for { + select { + case <-r.ctx.Done(): + r.Close() + default: + r.rw.Lock() + if r.ring.IsEmpty() { + if r.closed { // 如果已关闭并且没有数据,则关闭通道 + close(r.ch) + r.rw.Unlock() + return + } + // 等待数据 + r.cond.Wait() + } + vs := r.ring.ReadAll() + r.rw.Unlock() + for _, v := range vs { + r.ch <- v + } + } + } +}