diff --git a/server/bot.go b/server/bot.go new file mode 100644 index 0000000..d8d0e6c --- /dev/null +++ b/server/bot.go @@ -0,0 +1,69 @@ +package server + +import ( + "fmt" + "io" + "sync/atomic" + "time" +) + +// NewBot 创建一个机器人,目前仅支持 Socket 服务器 +func NewBot(srv *Server, options ...BotOption) *Bot { + if !srv.IsSocket() { + panic(fmt.Errorf("server type[%s] is not socket", srv.network)) + } + bot := &Bot{ + conn: newBotConn(srv), + } + for _, option := range options { + option(bot) + } + return bot +} + +type Bot struct { + conn *Conn + joined atomic.Bool +} + +// JoinServer 加入服务器 +func (slf *Bot) JoinServer() { + if slf.joined.Swap(true) { + slf.conn.server.OnConnectionClosedEvent(slf.conn, nil) + } + slf.conn.server.OnConnectionOpenedEvent(slf.conn) +} + +// LeaveServer 离开服务器 +func (slf *Bot) LeaveServer() { + if slf.joined.Swap(false) { + slf.conn.server.OnConnectionClosedEvent(slf.conn, nil) + } +} + +// SetNetworkDelay 设置网络延迟和波动范围 +// - delay 延迟 +// - fluctuation 波动范围 +func (slf *Bot) SetNetworkDelay(delay, fluctuation time.Duration) { + slf.conn.delay = delay + slf.conn.fluctuation = fluctuation +} + +// SetWriter 设置写入器 +func (slf *Bot) SetWriter(writer io.Writer) { + slf.conn.botWriter.Store(&writer) +} + +// SendPacket 发送数据包到服务器 +func (slf *Bot) SendPacket(packet []byte) { + if slf.conn.server.IsOnline(slf.conn.GetID()) { + slf.conn.server.PushPacketMessage(slf.conn, 0, packet) + } +} + +// SendWSPacket 发送 WebSocket 数据包到服务器 +func (slf *Bot) SendWSPacket(wst int, packet []byte) { + if slf.conn.server.IsOnline(slf.conn.GetID()) { + slf.conn.server.PushPacketMessage(slf.conn, wst, packet) + } +} diff --git a/server/bot_options.go b/server/bot_options.go new file mode 100644 index 0000000..5796a91 --- /dev/null +++ b/server/bot_options.go @@ -0,0 +1,26 @@ +package server + +import ( + "io" + "time" +) + +type BotOption func(bot *Bot) + +// WithBotNetworkDelay 设置机器人网络延迟及波动范围 +// - delay 延迟 +// - fluctuation 波动范围 +func WithBotNetworkDelay(delay, fluctuation time.Duration) BotOption { + return func(bot *Bot) { + bot.conn.delay = delay + bot.conn.fluctuation = fluctuation + } +} + +// WithBotWriter 设置机器人写入器,默认为 os.Stdout +func WithBotWriter(construction func(bot *Bot) io.Writer) BotOption { + return func(bot *Bot) { + writer := construction(bot) + bot.conn.botWriter.Store(&writer) + } +} diff --git a/server/bot_test.go b/server/bot_test.go new file mode 100644 index 0000000..94d10d8 --- /dev/null +++ b/server/bot_test.go @@ -0,0 +1,49 @@ +package server_test + +import ( + "github.com/kercylan98/minotaur/server" + "io" + "testing" + "time" +) + +type Writer struct { + t *testing.T + bot *server.Bot +} + +func (slf *Writer) Write(p []byte) (n int, err error) { + slf.t.Log(string(p)) + switch string(p) { + case "hello": + slf.bot.SendPacket([]byte("world")) + } + return 0, nil +} + +func TestNewBot(t *testing.T) { + srv := server.New(server.NetworkWebsocket) + + srv.RegConnectionOpenedEvent(func(srv *server.Server, conn *server.Conn) { + t.Logf("connection opened: %s", conn.GetID()) + conn.Close() + conn.Write([]byte("hello")) + }) + srv.RegConnectionClosedEvent(func(srv *server.Server, conn *server.Conn, err any) { + t.Logf("connection closed: %s", conn.GetID()) + }) + srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet []byte) { + t.Logf("connection %s receive packet: %s", conn.GetID(), string(packet)) + conn.Write([]byte("world")) + }) + srv.RegStartFinishEvent(func(srv *server.Server) { + bot := server.NewBot(srv, server.WithBotNetworkDelay(100, 20), server.WithBotWriter(func(bot *server.Bot) io.Writer { + return &Writer{t: t, bot: bot} + })) + bot.JoinServer() + time.Sleep(time.Second) + bot.SendPacket([]byte("hello")) + }) + + srv.Run(":9600") +} diff --git a/server/conn.go b/server/conn.go index df55cbd..e474faf 100644 --- a/server/conn.go +++ b/server/conn.go @@ -13,11 +13,14 @@ import ( "github.com/kercylan98/minotaur/utils/timer" "github.com/panjf2000/gnet" "github.com/xtaci/kcp-go/v5" + "io" "net" "net/http" + "os" "runtime/debug" "strings" "sync" + "sync/atomic" "time" ) @@ -80,18 +83,25 @@ func newWebsocketConn(server *Server, ws *websocket.Conn, ip string) *Conn { return c } -// NewEmptyConn 创建一个适用于测试的空连接 -func NewEmptyConn(server *Server) *Conn { +// newBotConn 创建一个适用于测试等情况的机器人连接 +func newBotConn(server *Server) *Conn { + ip, port := random.NetIP(), random.Port() + var writer io.Writer = os.Stdout c := &Conn{ ctx: server.ctx, connection: &connection{ - server: server, - remoteAddr: &net.TCPAddr{}, - ip: "0.0.0.0:0", - data: map[any]any{}, - openTime: time.Now(), + server: server, + remoteAddr: &net.TCPAddr{ + IP: ip, + Port: port, + Zone: "", + }, + ip: fmt.Sprintf("BOT:%s:%d", ip.String(), port), + data: map[any]any{}, + openTime: time.Now(), }, } + c.botWriter.Store(&writer) c.init() return c } @@ -104,20 +114,23 @@ type Conn struct { // connection 长久保持的连接 type connection struct { - server *Server - ticker *timer.Ticker - remoteAddr net.Addr - ip string - ws *websocket.Conn - gn gnet.Conn - kcp *kcp.UDPSession - gw func(packet []byte) - data map[any]any - closed bool - pool *concurrent.Pool[*connPacket] - loop *writeloop.WriteLoop[*connPacket] - mu sync.Mutex - openTime time.Time + server *Server + ticker *timer.Ticker + remoteAddr net.Addr + ip string + ws *websocket.Conn + gn gnet.Conn + kcp *kcp.UDPSession + gw func(packet []byte) + data map[any]any + closed bool + pool *concurrent.Pool[*connPacket] + loop *writeloop.WriteLoop[*connPacket] + mu sync.Mutex + openTime time.Time + delay time.Duration + fluctuation time.Duration + botWriter atomic.Pointer[io.Writer] } // Ticker 获取定时器 @@ -145,8 +158,8 @@ func (slf *Conn) GetWebsocketRequest() *http.Request { return slf.GetData(wsRequestKey).(*http.Request) } -// IsEmpty 是否是空连接 -func (slf *Conn) IsEmpty() bool { +// IsBot 是否是机器人连接 +func (slf *Conn) IsBot() bool { return slf.ws == nil && slf.gn == nil && slf.kcp == nil && slf.gw == nil } @@ -158,6 +171,9 @@ func (slf *Conn) RemoteAddr() net.Addr { // GetID 获取连接ID // - 为远程地址的字符串形式 func (slf *Conn) GetID() string { + if slf.IsBot() { + return slf.ip + } return slf.remoteAddr.String() } @@ -281,6 +297,14 @@ func (slf *Conn) init() { ) slf.loop = writeloop.NewWriteLoop[*connPacket](slf.pool, func(data *connPacket) error { var err error + if slf.delay > 0 || slf.fluctuation > 0 { + time.Sleep(random.Duration(int64(slf.delay-slf.fluctuation), int64(slf.delay+slf.fluctuation))) + _, err = (*slf.botWriter.Load()).Write(data.packet) + if data.callback != nil { + data.callback(err) + } + return err + } if slf.IsWebsocket() { err = slf.ws.WriteMessage(data.wst, data.packet) } else { diff --git a/server/server.go b/server/server.go index 2aebc2f..6f32284 100644 --- a/server/server.go +++ b/server/server.go @@ -387,6 +387,13 @@ func (slf *Server) Run(addr string) error { return nil } +// IsSocket 是否是 Socket 模式 +func (slf *Server) IsSocket() bool { + return slf.network == NetworkTcp || slf.network == NetworkTcp4 || slf.network == NetworkTcp6 || + slf.network == NetworkUdp || slf.network == NetworkUdp4 || slf.network == NetworkUdp6 || + slf.network == NetworkUnix || slf.network == NetworkKcp || slf.network == NetworkWebsocket +} + // RunNone 是 Run("") 的简写,仅适用于运行 NetworkNone 服务器 func (slf *Server) RunNone() error { return slf.Run(str.None) @@ -407,6 +414,18 @@ func (slf *Server) GetOnlineCount() int { return slf.online.Size() } +// GetOnlineBotCount 获取在线机器人数量 +func (slf *Server) GetOnlineBotCount() int { + var count int + slf.online.Range(func(id string, conn *Conn) bool { + if conn.IsBot() { + count++ + } + return true + }) + return count +} + // GetOnline 获取在线连接 func (slf *Server) GetOnline(id string) *Conn { return slf.online.Get(id) diff --git a/utils/random/ip.go b/utils/random/ip.go new file mode 100644 index 0000000..952552f --- /dev/null +++ b/utils/random/ip.go @@ -0,0 +1,26 @@ +package random + +import ( + "fmt" + "net" +) + +// NetIP 返回一个随机的IP地址 +func NetIP() net.IP { + return net.IPv4(byte(Int64(0, 255)), byte(Int64(0, 255)), byte(Int64(0, 255)), byte(Int64(0, 255))) +} + +// Port 返回一个随机的端口号 +func Port() int { + return Int(1, 65535) +} + +// IPv4 返回一个随机产生的IPv4地址。 +func IPv4() string { + return fmt.Sprintf("%d.%d.%d.%d", Int(1, 255), Int(0, 255), Int(0, 255), Int(0, 255)) +} + +// IPv4Port 返回一个随机产生的IPv4地址和端口。 +func IPv4Port() string { + return fmt.Sprintf("%d.%d.%d.%d:%d", Int(1, 255), Int(0, 255), Int(0, 255), Int(0, 255), Int(1, 65535)) +}