feat: server 包新增 WithWebsocketUpgrade 函数,支持自定义 websocket.Upgrader

This commit is contained in:
kercylan98 2023-12-25 14:40:02 +08:00
parent 7efe88a0f4
commit e960d07f49
4 changed files with 74 additions and 56 deletions

View File

@ -1,6 +1,8 @@
package server package server
import ( import (
"github.com/gorilla/websocket"
"net/http"
"time" "time"
) )
@ -17,3 +19,13 @@ const (
DefaultDispatcherBufferSize = 1024 * 16 DefaultDispatcherBufferSize = 1024 * 16
DefaultConnWriteBufferSize = 1024 * 1 DefaultConnWriteBufferSize = 1024 * 1
) )
func DefaultWebsocketUpgrader() *websocket.Upgrader {
return &websocket.Upgrader{
ReadBufferSize: 4096,
WriteBufferSize: 4096,
CheckOrigin: func(r *http.Request) bool {
return true
},
}
}

View File

@ -2,6 +2,7 @@ package server
import ( import (
"github.com/gin-contrib/pprof" "github.com/gin-contrib/pprof"
"github.com/gorilla/websocket"
"github.com/kercylan98/minotaur/utils/log" "github.com/kercylan98/minotaur/utils/log"
"github.com/kercylan98/minotaur/utils/timer" "github.com/kercylan98/minotaur/utils/timer"
"google.golang.org/grpc" "google.golang.org/grpc"
@ -30,25 +31,38 @@ type option struct {
} }
type runtime struct { type runtime struct {
deadlockDetect time.Duration // 是否开启死锁检测 deadlockDetect time.Duration // 是否开启死锁检测
supportMessageTypes map[int]bool // websocket模式下支持的消息类型 supportMessageTypes map[int]bool // websocket 模式下支持的消息类型
certFile, keyFile string // TLS文件 certFile, keyFile string // TLS文件
tickerPool *timer.Pool // 定时器池 tickerPool *timer.Pool // 定时器池
ticker *timer.Ticker // 定时器 ticker *timer.Ticker // 定时器
tickerAutonomy bool // 定时器是否独立运行 tickerAutonomy bool // 定时器是否独立运行
connTickerSize int // 连接定时器大小 connTickerSize int // 连接定时器大小
websocketReadDeadline time.Duration // websocket连接超时时间 websocketReadDeadline time.Duration // websocket 连接超时时间
websocketCompression int // websocket压缩等级 websocketCompression int // websocket 压缩等级
websocketWriteCompression bool // websocket写入压缩 websocketWriteCompression bool // websocket 写入压缩
limitLife time.Duration // 限制最大生命周期 limitLife time.Duration // 限制最大生命周期
packetWarnSize int // 数据包大小警告 packetWarnSize int // 数据包大小警告
messageStatisticsDuration time.Duration // 消息统计时长 messageStatisticsDuration time.Duration // 消息统计时长
messageStatisticsLimit int // 消息统计数量 messageStatisticsLimit int // 消息统计数量
messageStatistics []*atomic.Int64 // 消息统计数量 messageStatistics []*atomic.Int64 // 消息统计数量
messageStatisticsLock *sync.RWMutex // 消息统计锁 messageStatisticsLock *sync.RWMutex // 消息统计锁
dispatcherBufferSize int // 消息分发器缓冲区大小 dispatcherBufferSize int // 消息分发器缓冲区大小
connWriteBufferSize int // 连接写入缓冲区大小 connWriteBufferSize int // 连接写入缓冲区大小
disableAutomaticReleaseShunt bool // 是否禁用自动释放分流渠道 disableAutomaticReleaseShunt bool // 是否禁用自动释放分流渠道
websocketUpgrader *websocket.Upgrader // websocket 升级器
}
// WithWebsocketUpgrade 通过指定 websocket.Upgrader 的方式创建服务器
// - 默认值为 DefaultWebsocketUpgrader
// - 该选项仅在创建 NetworkWebsocket 服务器时有效
func WithWebsocketUpgrade(upgrader *websocket.Upgrader) Option {
return func(srv *Server) {
if srv.network != NetworkWebsocket {
return
}
srv.websocketUpgrader = upgrader
}
} }
// WithDisableAutomaticReleaseShunt 通过禁用自动释放分流渠道的方式创建服务器 // WithDisableAutomaticReleaseShunt 通过禁用自动释放分流渠道的方式创建服务器

View File

@ -5,7 +5,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/kercylan98/minotaur/server/internal/logger" "github.com/kercylan98/minotaur/server/internal/logger"
"github.com/kercylan98/minotaur/utils/concurrent" "github.com/kercylan98/minotaur/utils/concurrent"
"github.com/kercylan98/minotaur/utils/log" "github.com/kercylan98/minotaur/utils/log"
@ -196,12 +195,8 @@ func (slf *Server) Run(addr string) error {
go func(conn *Conn) { go func(conn *Conn) {
defer func() { defer func() {
if err := recover(); err != nil { if err := super.RecoverTransform(recover()); err != nil {
e, ok := err.(error) conn.Close(err)
if !ok {
e = fmt.Errorf("%v", err)
}
conn.Close(e)
} }
}() }()
@ -254,16 +249,12 @@ func (slf *Server) Run(addr string) error {
pattern = addr[index:] pattern = addr[index:]
slf.addr = slf.addr[:index] slf.addr = slf.addr[:index]
} }
var upgrade = websocket.Upgrader{ if slf.websocketUpgrader == nil {
ReadBufferSize: 4096, slf.websocketUpgrader = DefaultWebsocketUpgrader()
WriteBufferSize: 4096,
CheckOrigin: func(r *http.Request) bool {
return true
},
} }
http.HandleFunc(pattern, func(writer http.ResponseWriter, request *http.Request) { http.HandleFunc(pattern, func(writer http.ResponseWriter, request *http.Request) {
ip := request.Header.Get("X-Real-IP") ip := request.Header.Get("X-Real-IP")
ws, err := upgrade.Upgrade(writer, request, nil) ws, err := slf.websocketUpgrader.Upgrade(writer, request, nil)
if err != nil { if err != nil {
return return
} }
@ -289,12 +280,8 @@ func (slf *Server) Run(addr string) error {
slf.OnConnectionOpenedEvent(conn) slf.OnConnectionOpenedEvent(conn)
defer func() { defer func() {
if err := recover(); err != nil { if err := super.RecoverTransform(recover()); err != nil {
e, ok := err.(error) conn.Close(err)
if !ok {
e = fmt.Errorf("%v", err)
}
conn.Close(e)
} }
}() }()
for !conn.IsClosed() { for !conn.IsClosed() {
@ -734,15 +721,11 @@ func (slf *Server) dispatchMessage(dispatcher *dispatcher, msg *Message) {
if msg.t != MessageTypeAsync && msg.t != MessageTypeUniqueAsync && msg.t != MessageTypeShuntAsync && msg.t != MessageTypeUniqueShuntAsync { if msg.t != MessageTypeAsync && msg.t != MessageTypeUniqueAsync && msg.t != MessageTypeShuntAsync && msg.t != MessageTypeUniqueShuntAsync {
defer func(msg *Message) { defer func(msg *Message) {
super.Handle(cancel) super.Handle(cancel)
if err := recover(); err != nil { if err := super.RecoverTransform(recover()); err != nil {
stack := string(debug.Stack()) stack := string(debug.Stack())
log.Error("Server", log.String("MessageType", messageNames[msg.t]), log.String("Info", msg.String()), log.Any("error", err), log.String("stack", stack)) log.Error("Server", log.String("MessageType", messageNames[msg.t]), log.String("Info", msg.String()), log.Any("error", err), log.String("stack", stack))
fmt.Println(stack) fmt.Println(stack)
e, ok := err.(error) slf.OnMessageErrorEvent(msg, err)
if !ok {
e = fmt.Errorf("%v", err)
}
slf.OnMessageErrorEvent(msg, e)
} }
if msg.t == MessageTypeUniqueAsyncCallback || msg.t == MessageTypeUniqueShuntAsyncCallback { if msg.t == MessageTypeUniqueAsyncCallback || msg.t == MessageTypeUniqueShuntAsyncCallback {
dispatcher.antiUnique(msg.name) dispatcher.antiUnique(msg.name)
@ -782,18 +765,14 @@ func (slf *Server) dispatchMessage(dispatcher *dispatcher, msg *Message) {
case MessageTypeAsync, MessageTypeShuntAsync, MessageTypeUniqueAsync, MessageTypeUniqueShuntAsync: case MessageTypeAsync, MessageTypeShuntAsync, MessageTypeUniqueAsync, MessageTypeUniqueShuntAsync:
if err := slf.ants.Submit(func() { if err := slf.ants.Submit(func() {
defer func() { defer func() {
if err := recover(); err != nil { if err := super.RecoverTransform(recover()); err != nil {
if msg.t == MessageTypeUniqueAsync || msg.t == MessageTypeUniqueShuntAsync { if msg.t == MessageTypeUniqueAsync || msg.t == MessageTypeUniqueShuntAsync {
dispatcher.antiUnique(msg.name) dispatcher.antiUnique(msg.name)
} }
stack := string(debug.Stack()) stack := string(debug.Stack())
log.Error("Server", log.String("MessageType", messageNames[msg.t]), log.Any("error", err), log.String("stack", stack)) log.Error("Server", log.String("MessageType", messageNames[msg.t]), log.Any("error", err), log.String("stack", stack))
fmt.Println(stack) fmt.Println(stack)
e, ok := err.(error) slf.OnMessageErrorEvent(msg, err)
if !ok {
e = fmt.Errorf("%v", err)
}
slf.OnMessageErrorEvent(msg, e)
} }
super.Handle(cancel) super.Handle(cancel)
slf.low(msg, present, time.Second) slf.low(msg, present, time.Second)

View File

@ -1,21 +1,21 @@
package server_test package server_test
import ( import (
"fmt"
"github.com/kercylan98/minotaur/server" "github.com/kercylan98/minotaur/server"
"testing" "testing"
"time" "time"
) )
type TestService struct { type TestService struct{}
}
func (ts *TestService) OnInit(srv *server.Server) { func (ts *TestService) OnInit(srv *server.Server) {
srv.RegStartFinishEvent(func(srv *server.Server) { srv.RegStartFinishEvent(func(srv *server.Server) {
println("Server started") fmt.Println("server start finish")
}) })
srv.RegStopEvent(func(srv *server.Server) { srv.RegStopEvent(func(srv *server.Server) {
println("Server stopped") fmt.Println("server stop")
}) })
} }
@ -25,6 +25,19 @@ func TestBindService(t *testing.T) {
server.BindService(srv, new(TestService)) server.BindService(srv, new(TestService))
if err := srv.RunNone(); err != nil { if err := srv.RunNone(); err != nil {
panic(err) t.Fatal(err)
} }
} }
func ExampleBindService() {
srv := server.New(server.NetworkNone, server.WithLimitLife(time.Second))
server.BindService(srv, new(TestService))
if err := srv.RunNone(); err != nil {
panic(err)
}
// Output:
// server start finish
// server stop
}