feat: server 包新增 WithWebsocketUpgrade 函数,支持自定义 websocket.Upgrader
This commit is contained in:
parent
7efe88a0f4
commit
e960d07f49
|
@ -1,6 +1,8 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -17,3 +19,13 @@ const (
|
||||||
DefaultDispatcherBufferSize = 1024 * 16
|
DefaultDispatcherBufferSize = 1024 * 16
|
||||||
DefaultConnWriteBufferSize = 1024 * 1
|
DefaultConnWriteBufferSize = 1024 * 1
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func DefaultWebsocketUpgrader() *websocket.Upgrader {
|
||||||
|
return &websocket.Upgrader{
|
||||||
|
ReadBufferSize: 4096,
|
||||||
|
WriteBufferSize: 4096,
|
||||||
|
CheckOrigin: func(r *http.Request) bool {
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gin-contrib/pprof"
|
"github.com/gin-contrib/pprof"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
"github.com/kercylan98/minotaur/utils/log"
|
"github.com/kercylan98/minotaur/utils/log"
|
||||||
"github.com/kercylan98/minotaur/utils/timer"
|
"github.com/kercylan98/minotaur/utils/timer"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
@ -49,6 +50,19 @@ type runtime struct {
|
||||||
dispatcherBufferSize int // 消息分发器缓冲区大小
|
dispatcherBufferSize int // 消息分发器缓冲区大小
|
||||||
connWriteBufferSize int // 连接写入缓冲区大小
|
connWriteBufferSize int // 连接写入缓冲区大小
|
||||||
disableAutomaticReleaseShunt bool // 是否禁用自动释放分流渠道
|
disableAutomaticReleaseShunt bool // 是否禁用自动释放分流渠道
|
||||||
|
websocketUpgrader *websocket.Upgrader // websocket 升级器
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithWebsocketUpgrade 通过指定 websocket.Upgrader 的方式创建服务器
|
||||||
|
// - 默认值为 DefaultWebsocketUpgrader
|
||||||
|
// - 该选项仅在创建 NetworkWebsocket 服务器时有效
|
||||||
|
func WithWebsocketUpgrade(upgrader *websocket.Upgrader) Option {
|
||||||
|
return func(srv *Server) {
|
||||||
|
if srv.network != NetworkWebsocket {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
srv.websocketUpgrader = upgrader
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithDisableAutomaticReleaseShunt 通过禁用自动释放分流渠道的方式创建服务器
|
// WithDisableAutomaticReleaseShunt 通过禁用自动释放分流渠道的方式创建服务器
|
||||||
|
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
"github.com/kercylan98/minotaur/server/internal/logger"
|
"github.com/kercylan98/minotaur/server/internal/logger"
|
||||||
"github.com/kercylan98/minotaur/utils/concurrent"
|
"github.com/kercylan98/minotaur/utils/concurrent"
|
||||||
"github.com/kercylan98/minotaur/utils/log"
|
"github.com/kercylan98/minotaur/utils/log"
|
||||||
|
@ -196,12 +195,8 @@ func (slf *Server) Run(addr string) error {
|
||||||
|
|
||||||
go func(conn *Conn) {
|
go func(conn *Conn) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := recover(); err != nil {
|
if err := super.RecoverTransform(recover()); err != nil {
|
||||||
e, ok := err.(error)
|
conn.Close(err)
|
||||||
if !ok {
|
|
||||||
e = fmt.Errorf("%v", err)
|
|
||||||
}
|
|
||||||
conn.Close(e)
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -254,16 +249,12 @@ func (slf *Server) Run(addr string) error {
|
||||||
pattern = addr[index:]
|
pattern = addr[index:]
|
||||||
slf.addr = slf.addr[:index]
|
slf.addr = slf.addr[:index]
|
||||||
}
|
}
|
||||||
var upgrade = websocket.Upgrader{
|
if slf.websocketUpgrader == nil {
|
||||||
ReadBufferSize: 4096,
|
slf.websocketUpgrader = DefaultWebsocketUpgrader()
|
||||||
WriteBufferSize: 4096,
|
|
||||||
CheckOrigin: func(r *http.Request) bool {
|
|
||||||
return true
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
http.HandleFunc(pattern, func(writer http.ResponseWriter, request *http.Request) {
|
http.HandleFunc(pattern, func(writer http.ResponseWriter, request *http.Request) {
|
||||||
ip := request.Header.Get("X-Real-IP")
|
ip := request.Header.Get("X-Real-IP")
|
||||||
ws, err := upgrade.Upgrade(writer, request, nil)
|
ws, err := slf.websocketUpgrader.Upgrade(writer, request, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -289,12 +280,8 @@ func (slf *Server) Run(addr string) error {
|
||||||
slf.OnConnectionOpenedEvent(conn)
|
slf.OnConnectionOpenedEvent(conn)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := recover(); err != nil {
|
if err := super.RecoverTransform(recover()); err != nil {
|
||||||
e, ok := err.(error)
|
conn.Close(err)
|
||||||
if !ok {
|
|
||||||
e = fmt.Errorf("%v", err)
|
|
||||||
}
|
|
||||||
conn.Close(e)
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
for !conn.IsClosed() {
|
for !conn.IsClosed() {
|
||||||
|
@ -734,15 +721,11 @@ func (slf *Server) dispatchMessage(dispatcher *dispatcher, msg *Message) {
|
||||||
if msg.t != MessageTypeAsync && msg.t != MessageTypeUniqueAsync && msg.t != MessageTypeShuntAsync && msg.t != MessageTypeUniqueShuntAsync {
|
if msg.t != MessageTypeAsync && msg.t != MessageTypeUniqueAsync && msg.t != MessageTypeShuntAsync && msg.t != MessageTypeUniqueShuntAsync {
|
||||||
defer func(msg *Message) {
|
defer func(msg *Message) {
|
||||||
super.Handle(cancel)
|
super.Handle(cancel)
|
||||||
if err := recover(); err != nil {
|
if err := super.RecoverTransform(recover()); err != nil {
|
||||||
stack := string(debug.Stack())
|
stack := string(debug.Stack())
|
||||||
log.Error("Server", log.String("MessageType", messageNames[msg.t]), log.String("Info", msg.String()), log.Any("error", err), log.String("stack", stack))
|
log.Error("Server", log.String("MessageType", messageNames[msg.t]), log.String("Info", msg.String()), log.Any("error", err), log.String("stack", stack))
|
||||||
fmt.Println(stack)
|
fmt.Println(stack)
|
||||||
e, ok := err.(error)
|
slf.OnMessageErrorEvent(msg, err)
|
||||||
if !ok {
|
|
||||||
e = fmt.Errorf("%v", err)
|
|
||||||
}
|
|
||||||
slf.OnMessageErrorEvent(msg, e)
|
|
||||||
}
|
}
|
||||||
if msg.t == MessageTypeUniqueAsyncCallback || msg.t == MessageTypeUniqueShuntAsyncCallback {
|
if msg.t == MessageTypeUniqueAsyncCallback || msg.t == MessageTypeUniqueShuntAsyncCallback {
|
||||||
dispatcher.antiUnique(msg.name)
|
dispatcher.antiUnique(msg.name)
|
||||||
|
@ -782,18 +765,14 @@ func (slf *Server) dispatchMessage(dispatcher *dispatcher, msg *Message) {
|
||||||
case MessageTypeAsync, MessageTypeShuntAsync, MessageTypeUniqueAsync, MessageTypeUniqueShuntAsync:
|
case MessageTypeAsync, MessageTypeShuntAsync, MessageTypeUniqueAsync, MessageTypeUniqueShuntAsync:
|
||||||
if err := slf.ants.Submit(func() {
|
if err := slf.ants.Submit(func() {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := recover(); err != nil {
|
if err := super.RecoverTransform(recover()); err != nil {
|
||||||
if msg.t == MessageTypeUniqueAsync || msg.t == MessageTypeUniqueShuntAsync {
|
if msg.t == MessageTypeUniqueAsync || msg.t == MessageTypeUniqueShuntAsync {
|
||||||
dispatcher.antiUnique(msg.name)
|
dispatcher.antiUnique(msg.name)
|
||||||
}
|
}
|
||||||
stack := string(debug.Stack())
|
stack := string(debug.Stack())
|
||||||
log.Error("Server", log.String("MessageType", messageNames[msg.t]), log.Any("error", err), log.String("stack", stack))
|
log.Error("Server", log.String("MessageType", messageNames[msg.t]), log.Any("error", err), log.String("stack", stack))
|
||||||
fmt.Println(stack)
|
fmt.Println(stack)
|
||||||
e, ok := err.(error)
|
slf.OnMessageErrorEvent(msg, err)
|
||||||
if !ok {
|
|
||||||
e = fmt.Errorf("%v", err)
|
|
||||||
}
|
|
||||||
slf.OnMessageErrorEvent(msg, e)
|
|
||||||
}
|
}
|
||||||
super.Handle(cancel)
|
super.Handle(cancel)
|
||||||
slf.low(msg, present, time.Second)
|
slf.low(msg, present, time.Second)
|
||||||
|
|
|
@ -1,21 +1,21 @@
|
||||||
package server_test
|
package server_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"github.com/kercylan98/minotaur/server"
|
"github.com/kercylan98/minotaur/server"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TestService struct {
|
type TestService struct{}
|
||||||
}
|
|
||||||
|
|
||||||
func (ts *TestService) OnInit(srv *server.Server) {
|
func (ts *TestService) OnInit(srv *server.Server) {
|
||||||
srv.RegStartFinishEvent(func(srv *server.Server) {
|
srv.RegStartFinishEvent(func(srv *server.Server) {
|
||||||
println("Server started")
|
fmt.Println("server start finish")
|
||||||
})
|
})
|
||||||
|
|
||||||
srv.RegStopEvent(func(srv *server.Server) {
|
srv.RegStopEvent(func(srv *server.Server) {
|
||||||
println("Server stopped")
|
fmt.Println("server stop")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -24,7 +24,20 @@ func TestBindService(t *testing.T) {
|
||||||
|
|
||||||
server.BindService(srv, new(TestService))
|
server.BindService(srv, new(TestService))
|
||||||
|
|
||||||
|
if err := srv.RunNone(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleBindService() {
|
||||||
|
srv := server.New(server.NetworkNone, server.WithLimitLife(time.Second))
|
||||||
|
server.BindService(srv, new(TestService))
|
||||||
|
|
||||||
if err := srv.RunNone(); err != nil {
|
if err := srv.RunNone(); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Output:
|
||||||
|
// server start finish
|
||||||
|
// server stop
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue