feat: server 包新增 WithWebsocketUpgrade 函数,支持自定义 websocket.Upgrader

This commit is contained in:
kercylan98 2023-12-25 14:40:02 +08:00
parent 7efe88a0f4
commit e960d07f49
4 changed files with 74 additions and 56 deletions

View File

@ -1,6 +1,8 @@
package server
import (
"github.com/gorilla/websocket"
"net/http"
"time"
)
@ -17,3 +19,13 @@ const (
DefaultDispatcherBufferSize = 1024 * 16
DefaultConnWriteBufferSize = 1024 * 1
)
func DefaultWebsocketUpgrader() *websocket.Upgrader {
return &websocket.Upgrader{
ReadBufferSize: 4096,
WriteBufferSize: 4096,
CheckOrigin: func(r *http.Request) bool {
return true
},
}
}

View File

@ -2,6 +2,7 @@ package server
import (
"github.com/gin-contrib/pprof"
"github.com/gorilla/websocket"
"github.com/kercylan98/minotaur/utils/log"
"github.com/kercylan98/minotaur/utils/timer"
"google.golang.org/grpc"
@ -49,6 +50,19 @@ type runtime struct {
dispatcherBufferSize int // 消息分发器缓冲区大小
connWriteBufferSize int // 连接写入缓冲区大小
disableAutomaticReleaseShunt bool // 是否禁用自动释放分流渠道
websocketUpgrader *websocket.Upgrader // websocket 升级器
}
// WithWebsocketUpgrade 通过指定 websocket.Upgrader 的方式创建服务器
// - 默认值为 DefaultWebsocketUpgrader
// - 该选项仅在创建 NetworkWebsocket 服务器时有效
func WithWebsocketUpgrade(upgrader *websocket.Upgrader) Option {
return func(srv *Server) {
if srv.network != NetworkWebsocket {
return
}
srv.websocketUpgrader = upgrader
}
}
// WithDisableAutomaticReleaseShunt 通过禁用自动释放分流渠道的方式创建服务器

View File

@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/kercylan98/minotaur/server/internal/logger"
"github.com/kercylan98/minotaur/utils/concurrent"
"github.com/kercylan98/minotaur/utils/log"
@ -196,12 +195,8 @@ func (slf *Server) Run(addr string) error {
go func(conn *Conn) {
defer func() {
if err := recover(); err != nil {
e, ok := err.(error)
if !ok {
e = fmt.Errorf("%v", err)
}
conn.Close(e)
if err := super.RecoverTransform(recover()); err != nil {
conn.Close(err)
}
}()
@ -254,16 +249,12 @@ func (slf *Server) Run(addr string) error {
pattern = addr[index:]
slf.addr = slf.addr[:index]
}
var upgrade = websocket.Upgrader{
ReadBufferSize: 4096,
WriteBufferSize: 4096,
CheckOrigin: func(r *http.Request) bool {
return true
},
if slf.websocketUpgrader == nil {
slf.websocketUpgrader = DefaultWebsocketUpgrader()
}
http.HandleFunc(pattern, func(writer http.ResponseWriter, request *http.Request) {
ip := request.Header.Get("X-Real-IP")
ws, err := upgrade.Upgrade(writer, request, nil)
ws, err := slf.websocketUpgrader.Upgrade(writer, request, nil)
if err != nil {
return
}
@ -289,12 +280,8 @@ func (slf *Server) Run(addr string) error {
slf.OnConnectionOpenedEvent(conn)
defer func() {
if err := recover(); err != nil {
e, ok := err.(error)
if !ok {
e = fmt.Errorf("%v", err)
}
conn.Close(e)
if err := super.RecoverTransform(recover()); err != nil {
conn.Close(err)
}
}()
for !conn.IsClosed() {
@ -734,15 +721,11 @@ func (slf *Server) dispatchMessage(dispatcher *dispatcher, msg *Message) {
if msg.t != MessageTypeAsync && msg.t != MessageTypeUniqueAsync && msg.t != MessageTypeShuntAsync && msg.t != MessageTypeUniqueShuntAsync {
defer func(msg *Message) {
super.Handle(cancel)
if err := recover(); err != nil {
if err := super.RecoverTransform(recover()); err != nil {
stack := string(debug.Stack())
log.Error("Server", log.String("MessageType", messageNames[msg.t]), log.String("Info", msg.String()), log.Any("error", err), log.String("stack", stack))
fmt.Println(stack)
e, ok := err.(error)
if !ok {
e = fmt.Errorf("%v", err)
}
slf.OnMessageErrorEvent(msg, e)
slf.OnMessageErrorEvent(msg, err)
}
if msg.t == MessageTypeUniqueAsyncCallback || msg.t == MessageTypeUniqueShuntAsyncCallback {
dispatcher.antiUnique(msg.name)
@ -782,18 +765,14 @@ func (slf *Server) dispatchMessage(dispatcher *dispatcher, msg *Message) {
case MessageTypeAsync, MessageTypeShuntAsync, MessageTypeUniqueAsync, MessageTypeUniqueShuntAsync:
if err := slf.ants.Submit(func() {
defer func() {
if err := recover(); err != nil {
if err := super.RecoverTransform(recover()); err != nil {
if msg.t == MessageTypeUniqueAsync || msg.t == MessageTypeUniqueShuntAsync {
dispatcher.antiUnique(msg.name)
}
stack := string(debug.Stack())
log.Error("Server", log.String("MessageType", messageNames[msg.t]), log.Any("error", err), log.String("stack", stack))
fmt.Println(stack)
e, ok := err.(error)
if !ok {
e = fmt.Errorf("%v", err)
}
slf.OnMessageErrorEvent(msg, e)
slf.OnMessageErrorEvent(msg, err)
}
super.Handle(cancel)
slf.low(msg, present, time.Second)

View File

@ -1,21 +1,21 @@
package server_test
import (
"fmt"
"github.com/kercylan98/minotaur/server"
"testing"
"time"
)
type TestService struct {
}
type TestService struct{}
func (ts *TestService) OnInit(srv *server.Server) {
srv.RegStartFinishEvent(func(srv *server.Server) {
println("Server started")
fmt.Println("server start finish")
})
srv.RegStopEvent(func(srv *server.Server) {
println("Server stopped")
fmt.Println("server stop")
})
}
@ -24,7 +24,20 @@ func TestBindService(t *testing.T) {
server.BindService(srv, new(TestService))
if err := srv.RunNone(); err != nil {
t.Fatal(err)
}
}
func ExampleBindService() {
srv := server.New(server.NetworkNone, server.WithLimitLife(time.Second))
server.BindService(srv, new(TestService))
if err := srv.RunNone(); err != nil {
panic(err)
}
// Output:
// server start finish
// server stop
}