feat: server 包新增 WithWebsocketUpgrade 函数,支持自定义 websocket.Upgrader
This commit is contained in:
parent
7efe88a0f4
commit
e960d07f49
|
@ -1,6 +1,8 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"github.com/gorilla/websocket"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -17,3 +19,13 @@ const (
|
|||
DefaultDispatcherBufferSize = 1024 * 16
|
||||
DefaultConnWriteBufferSize = 1024 * 1
|
||||
)
|
||||
|
||||
func DefaultWebsocketUpgrader() *websocket.Upgrader {
|
||||
return &websocket.Upgrader{
|
||||
ReadBufferSize: 4096,
|
||||
WriteBufferSize: 4096,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package server
|
|||
|
||||
import (
|
||||
"github.com/gin-contrib/pprof"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/kercylan98/minotaur/utils/log"
|
||||
"github.com/kercylan98/minotaur/utils/timer"
|
||||
"google.golang.org/grpc"
|
||||
|
@ -49,6 +50,19 @@ type runtime struct {
|
|||
dispatcherBufferSize int // 消息分发器缓冲区大小
|
||||
connWriteBufferSize int // 连接写入缓冲区大小
|
||||
disableAutomaticReleaseShunt bool // 是否禁用自动释放分流渠道
|
||||
websocketUpgrader *websocket.Upgrader // websocket 升级器
|
||||
}
|
||||
|
||||
// WithWebsocketUpgrade 通过指定 websocket.Upgrader 的方式创建服务器
|
||||
// - 默认值为 DefaultWebsocketUpgrader
|
||||
// - 该选项仅在创建 NetworkWebsocket 服务器时有效
|
||||
func WithWebsocketUpgrade(upgrader *websocket.Upgrader) Option {
|
||||
return func(srv *Server) {
|
||||
if srv.network != NetworkWebsocket {
|
||||
return
|
||||
}
|
||||
srv.websocketUpgrader = upgrader
|
||||
}
|
||||
}
|
||||
|
||||
// WithDisableAutomaticReleaseShunt 通过禁用自动释放分流渠道的方式创建服务器
|
||||
|
|
|
@ -5,7 +5,6 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/kercylan98/minotaur/server/internal/logger"
|
||||
"github.com/kercylan98/minotaur/utils/concurrent"
|
||||
"github.com/kercylan98/minotaur/utils/log"
|
||||
|
@ -196,12 +195,8 @@ func (slf *Server) Run(addr string) error {
|
|||
|
||||
go func(conn *Conn) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
e, ok := err.(error)
|
||||
if !ok {
|
||||
e = fmt.Errorf("%v", err)
|
||||
}
|
||||
conn.Close(e)
|
||||
if err := super.RecoverTransform(recover()); err != nil {
|
||||
conn.Close(err)
|
||||
}
|
||||
}()
|
||||
|
||||
|
@ -254,16 +249,12 @@ func (slf *Server) Run(addr string) error {
|
|||
pattern = addr[index:]
|
||||
slf.addr = slf.addr[:index]
|
||||
}
|
||||
var upgrade = websocket.Upgrader{
|
||||
ReadBufferSize: 4096,
|
||||
WriteBufferSize: 4096,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
if slf.websocketUpgrader == nil {
|
||||
slf.websocketUpgrader = DefaultWebsocketUpgrader()
|
||||
}
|
||||
http.HandleFunc(pattern, func(writer http.ResponseWriter, request *http.Request) {
|
||||
ip := request.Header.Get("X-Real-IP")
|
||||
ws, err := upgrade.Upgrade(writer, request, nil)
|
||||
ws, err := slf.websocketUpgrader.Upgrade(writer, request, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -289,12 +280,8 @@ func (slf *Server) Run(addr string) error {
|
|||
slf.OnConnectionOpenedEvent(conn)
|
||||
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
e, ok := err.(error)
|
||||
if !ok {
|
||||
e = fmt.Errorf("%v", err)
|
||||
}
|
||||
conn.Close(e)
|
||||
if err := super.RecoverTransform(recover()); err != nil {
|
||||
conn.Close(err)
|
||||
}
|
||||
}()
|
||||
for !conn.IsClosed() {
|
||||
|
@ -734,15 +721,11 @@ func (slf *Server) dispatchMessage(dispatcher *dispatcher, msg *Message) {
|
|||
if msg.t != MessageTypeAsync && msg.t != MessageTypeUniqueAsync && msg.t != MessageTypeShuntAsync && msg.t != MessageTypeUniqueShuntAsync {
|
||||
defer func(msg *Message) {
|
||||
super.Handle(cancel)
|
||||
if err := recover(); err != nil {
|
||||
if err := super.RecoverTransform(recover()); err != nil {
|
||||
stack := string(debug.Stack())
|
||||
log.Error("Server", log.String("MessageType", messageNames[msg.t]), log.String("Info", msg.String()), log.Any("error", err), log.String("stack", stack))
|
||||
fmt.Println(stack)
|
||||
e, ok := err.(error)
|
||||
if !ok {
|
||||
e = fmt.Errorf("%v", err)
|
||||
}
|
||||
slf.OnMessageErrorEvent(msg, e)
|
||||
slf.OnMessageErrorEvent(msg, err)
|
||||
}
|
||||
if msg.t == MessageTypeUniqueAsyncCallback || msg.t == MessageTypeUniqueShuntAsyncCallback {
|
||||
dispatcher.antiUnique(msg.name)
|
||||
|
@ -782,18 +765,14 @@ func (slf *Server) dispatchMessage(dispatcher *dispatcher, msg *Message) {
|
|||
case MessageTypeAsync, MessageTypeShuntAsync, MessageTypeUniqueAsync, MessageTypeUniqueShuntAsync:
|
||||
if err := slf.ants.Submit(func() {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
if err := super.RecoverTransform(recover()); err != nil {
|
||||
if msg.t == MessageTypeUniqueAsync || msg.t == MessageTypeUniqueShuntAsync {
|
||||
dispatcher.antiUnique(msg.name)
|
||||
}
|
||||
stack := string(debug.Stack())
|
||||
log.Error("Server", log.String("MessageType", messageNames[msg.t]), log.Any("error", err), log.String("stack", stack))
|
||||
fmt.Println(stack)
|
||||
e, ok := err.(error)
|
||||
if !ok {
|
||||
e = fmt.Errorf("%v", err)
|
||||
}
|
||||
slf.OnMessageErrorEvent(msg, e)
|
||||
slf.OnMessageErrorEvent(msg, err)
|
||||
}
|
||||
super.Handle(cancel)
|
||||
slf.low(msg, present, time.Second)
|
||||
|
|
|
@ -1,21 +1,21 @@
|
|||
package server_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/kercylan98/minotaur/server"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type TestService struct {
|
||||
}
|
||||
type TestService struct{}
|
||||
|
||||
func (ts *TestService) OnInit(srv *server.Server) {
|
||||
srv.RegStartFinishEvent(func(srv *server.Server) {
|
||||
println("Server started")
|
||||
fmt.Println("server start finish")
|
||||
})
|
||||
|
||||
srv.RegStopEvent(func(srv *server.Server) {
|
||||
println("Server stopped")
|
||||
fmt.Println("server stop")
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -24,7 +24,20 @@ func TestBindService(t *testing.T) {
|
|||
|
||||
server.BindService(srv, new(TestService))
|
||||
|
||||
if err := srv.RunNone(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleBindService() {
|
||||
srv := server.New(server.NetworkNone, server.WithLimitLife(time.Second))
|
||||
server.BindService(srv, new(TestService))
|
||||
|
||||
if err := srv.RunNone(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Output:
|
||||
// server start finish
|
||||
// server stop
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue