other: 新 server 包调整

This commit is contained in:
kercylan
2024-03-20 00:13:31 +08:00
parent 37f35aa602
commit 7239a278ee
11 changed files with 354 additions and 91 deletions

View File

@@ -1,18 +1,18 @@
package server
import (
"fmt"
"github.com/panjf2000/ants/v2"
"github.com/panjf2000/gnet/v2"
"time"
)
func newEventHandler(trafficker Trafficker) (handler *eventHandler, err error) {
func newEventHandler(options *Options, trafficker Trafficker) (handler *eventHandler, err error) {
var wp *ants.Pool
if wp, err = ants.NewPool(ants.DefaultAntsPoolSize, ants.WithNonblocking(true)); err != nil {
return
}
handler = &eventHandler{
options: options,
trafficker: trafficker,
workerPool: wp,
}
@@ -21,17 +21,18 @@ func newEventHandler(trafficker Trafficker) (handler *eventHandler, err error) {
type (
Trafficker interface {
OnBoot() error
OnBoot(options *Options) error
OnTraffic(c gnet.Conn, packet []byte)
}
eventHandler struct {
options *Options
trafficker Trafficker
workerPool *ants.Pool
}
)
func (e *eventHandler) OnBoot(eng gnet.Engine) (action gnet.Action) {
if err := e.trafficker.OnBoot(); err != nil {
if err := e.trafficker.OnBoot(e.options); err != nil {
action = gnet.Shutdown
}
return
@@ -46,7 +47,6 @@ func (e *eventHandler) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
}
func (e *eventHandler) OnClose(c gnet.Conn, err error) (action gnet.Action) {
fmt.Println("断开")
return
}

4
server/v2/options.go Normal file
View File

@@ -0,0 +1,4 @@
package server
type Options struct {
}

View File

@@ -15,6 +15,6 @@ type Server struct {
func (s *Server) Run(protoAddr string) (err error) {
var handler *eventHandler
handler, err = newEventHandler(s.trafficker)
handler, err = newEventHandler(new(Options), s.trafficker)
return gnet.Run(handler, protoAddr)
}

29
server/v2/server_test.go Normal file
View File

@@ -0,0 +1,29 @@
package server_test
import (
"github.com/gin-gonic/gin"
"github.com/kercylan98/minotaur/server/v2"
"github.com/kercylan98/minotaur/server/v2/traffickers"
"net/http"
"testing"
)
func TestNewServer(t *testing.T) {
r := gin.New()
r.GET("/", func(context *gin.Context) {
context.JSON(200, gin.H{
"ping": "pong",
})
})
srv := server.NewServer(traffickers.WebSocket(r, func(handler *gin.Engine, upgradeHandler func(writer http.ResponseWriter, request *http.Request) error) {
handler.GET("/ws", func(context *gin.Context) {
if err := upgradeHandler(context.Writer, context.Request); err != nil {
context.AbortWithError(500, err)
}
})
}))
if err := srv.Run("tcp://:8080"); err != nil {
panic(err)
}
}

View File

@@ -9,18 +9,22 @@ import (
netHttp "net/http"
)
func Http(handler netHttp.Handler) server.Trafficker {
return &http{
func Http[H netHttp.Handler](handler H) server.Trafficker {
return &http[H]{
handler: handler,
ncb: func(c gnet.Conn, err error) error {
return nil
},
}
}
type http struct {
handler netHttp.Handler
type http[H netHttp.Handler] struct {
handler H
rwp *hub.ObjectPool[*httpResponseWriter]
ncb func(c gnet.Conn, err error) error
}
func (h *http) OnBoot() error {
func (h *http[H]) OnBoot(options *server.Options) error {
h.rwp = hub.NewObjectPool[httpResponseWriter](func() *httpResponseWriter {
return new(httpResponseWriter)
}, func(data *httpResponseWriter) {
@@ -29,14 +33,24 @@ func (h *http) OnBoot() error {
return nil
}
func (h *http) OnTraffic(c gnet.Conn, packet []byte) {
func (h *http[H]) OnTraffic(c gnet.Conn, packet []byte) {
var responseWriter *httpResponseWriter
defer func() {
if responseWriter == nil || !responseWriter.isHijack {
_ = c.Close()
}
}()
httpRequest, err := netHttp.ReadRequest(bufio.NewReader(bytes.NewReader(packet)))
if err != nil {
return
}
responseWriter := h.rwp.Get()
responseWriter = h.rwp.Get()
responseWriter.init(c)
h.handler.ServeHTTP(responseWriter, httpRequest)
if responseWriter.isHijack {
return
}
_ = responseWriter.Result().Write(c)
}

View File

@@ -0,0 +1,211 @@
package traffickers
import (
"bufio"
"bytes"
"fmt"
"github.com/panjf2000/gnet/v2"
"io"
"net"
netHttp "net/http"
"net/textproto"
"strconv"
"strings"
"golang.org/x/net/http/httpguts"
)
type httpResponseWriter struct {
Code int
HeaderMap netHttp.Header
Body *bytes.Buffer
Flushed bool
conn *websocketConn
result *netHttp.Response
snapHeader netHttp.Header
wroteHeader bool
isHijack bool
}
func (rw *httpResponseWriter) init(c gnet.Conn) {
rw.conn = &websocketConn{Conn: c}
rw.Code = 200
rw.Body = new(bytes.Buffer)
rw.HeaderMap = make(netHttp.Header)
rw.isHijack = false
}
func (rw *httpResponseWriter) reset() {
rw.conn = nil
rw.Code = 200
rw.Body = nil
rw.HeaderMap = nil
rw.isHijack = false
}
func (rw *httpResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if !rw.isHijack {
return rw.conn, bufio.NewReadWriter(bufio.NewReader(rw.conn), bufio.NewWriter(rw.conn)), nil
}
return nil, nil, netHttp.ErrHijacked
}
func (rw *httpResponseWriter) Header() netHttp.Header {
m := rw.HeaderMap
if m == nil {
m = make(netHttp.Header)
rw.HeaderMap = m
}
return m
}
func (rw *httpResponseWriter) writeHeader(b []byte, str string) {
if rw.wroteHeader {
return
}
if len(str) > 512 {
str = str[:512]
}
m := rw.Header()
_, hasType := m["Content-Type"]
hasTE := m.Get("Transfer-Encoding") != ""
if !hasType && !hasTE {
if b == nil {
b = []byte(str)
}
m.Set("Content-Type", netHttp.DetectContentType(b))
}
rw.WriteHeader(200)
}
func (rw *httpResponseWriter) Write(buf []byte) (n int, err error) {
if rw.isHijack {
n = len(buf)
var wait = make(chan error)
if err = rw.conn.AsyncWrite(buf, func(c gnet.Conn, err error) error {
if err != nil {
wait <- err
}
return nil
}); err != nil {
return
}
err = <-wait
return
}
rw.writeHeader(buf, "")
if rw.Body != nil {
rw.Body.Write(buf)
}
return len(buf), nil
}
func (rw *httpResponseWriter) WriteString(str string) (int, error) {
rw.writeHeader(nil, str)
if rw.Body != nil {
rw.Body.WriteString(str)
}
return len(str), nil
}
func checkWriteHeaderCode(code int) {
if code < 100 || code > 999 {
panic(fmt.Sprintf("invalid WriteHeader code %v", code))
}
}
func (rw *httpResponseWriter) WriteHeader(code int) {
if rw.wroteHeader {
return
}
checkWriteHeaderCode(code)
rw.Code = code
rw.wroteHeader = true
if rw.HeaderMap == nil {
rw.HeaderMap = make(netHttp.Header)
}
rw.snapHeader = rw.HeaderMap.Clone()
}
func (rw *httpResponseWriter) Flush() {
if !rw.wroteHeader {
rw.WriteHeader(200)
}
rw.Flushed = true
}
func (rw *httpResponseWriter) Result() *netHttp.Response {
if rw.result != nil {
return rw.result
}
if rw.snapHeader == nil {
rw.snapHeader = rw.HeaderMap.Clone()
}
res := &netHttp.Response{
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
StatusCode: rw.Code,
Header: rw.snapHeader,
}
rw.result = res
if res.StatusCode == 0 {
res.StatusCode = 200
}
res.Status = fmt.Sprintf("%03d %s", res.StatusCode, netHttp.StatusText(res.StatusCode))
if rw.Body != nil {
res.Body = io.NopCloser(bytes.NewReader(rw.Body.Bytes()))
} else {
res.Body = netHttp.NoBody
}
res.ContentLength = parseContentLength(res.Header.Get("Content-Length"))
if trailers, ok := rw.snapHeader["Trailer"]; ok {
res.Trailer = make(netHttp.Header, len(trailers))
for _, k := range trailers {
for _, k := range strings.Split(k, ",") {
k = netHttp.CanonicalHeaderKey(textproto.TrimString(k))
if !httpguts.ValidTrailerHeader(k) {
// Ignore since forbidden by RFC 7230, section 4.1.2.
continue
}
vv, ok := rw.HeaderMap[k]
if !ok {
continue
}
vv2 := make([]string, len(vv))
copy(vv2, vv)
res.Trailer[k] = vv2
}
}
}
for k, vv := range rw.HeaderMap {
if !strings.HasPrefix(k, netHttp.TrailerPrefix) {
continue
}
if res.Trailer == nil {
res.Trailer = make(netHttp.Header)
}
for _, v := range vv {
res.Trailer.Add(strings.TrimPrefix(k, netHttp.TrailerPrefix), v)
}
}
return res
}
func parseContentLength(cl string) int64 {
cl = textproto.TrimString(cl)
if cl == "" {
return -1
}
n, err := strconv.ParseUint(cl, 10, 63)
if err != nil {
return -1
}
return int64(n)
}

View File

@@ -1,58 +0,0 @@
package traffickers
import (
"bytes"
"github.com/panjf2000/gnet/v2"
netHttp "net/http"
"strconv"
"sync"
)
type httpResponseWriter struct {
c gnet.Conn
statusCode int
header netHttp.Header
}
func (w *httpResponseWriter) init(c gnet.Conn) {
w.c = c
w.statusCode = 200
w.header = make(netHttp.Header)
}
func (w *httpResponseWriter) reset() {
w.c = nil
w.statusCode = 200
w.header = nil
}
func (w *httpResponseWriter) Header() netHttp.Header {
return w.header
}
func (w *httpResponseWriter) Write(b []byte) (n int, err error) {
var buf bytes.Buffer
buf.WriteString("HTTP/1.1 ")
buf.WriteString(netHttp.StatusText(w.statusCode))
buf.WriteString("\r\n")
w.header.Set("Content-Length", strconv.Itoa(len(b)))
if err = w.header.Write(&buf); err != nil {
return
}
buf.WriteString("\r\n")
buf.Write(b)
res := buf.Bytes()
var wg sync.WaitGroup
wg.Add(1)
err = w.c.AsyncWrite(res, func(c gnet.Conn, e error) error {
err = e
wg.Done()
return nil
})
wg.Wait()
return len(res), err
}
func (w *httpResponseWriter) WriteHeader(statusCode int) {
w.statusCode = statusCode
}

View File

@@ -0,0 +1,64 @@
package traffickers
import (
"fmt"
ws "github.com/gorilla/websocket"
"github.com/kercylan98/minotaur/server/v2"
"github.com/panjf2000/gnet/v2"
netHttp "net/http"
)
func WebSocket[H netHttp.Handler](handler H, binder func(handler H, upgradeHandler func(writer netHttp.ResponseWriter, request *netHttp.Request) error)) server.Trafficker {
w := &websocket[H]{
http: Http(handler).(*http[H]),
binder: binder,
upgrader: &ws.Upgrader{
ReadBufferSize: 4096,
WriteBufferSize: 4096,
CheckOrigin: func(r *netHttp.Request) bool {
return true
},
},
}
binder(handler, w.OnUpgrade)
return w
}
type websocket[H netHttp.Handler] struct {
*http[H]
binder func(handler H, upgradeHandler func(writer netHttp.ResponseWriter, request *netHttp.Request) error)
upgrader *ws.Upgrader
}
func (w *websocket[H]) OnBoot(options *server.Options) error {
return w.http.OnBoot(options)
}
func (w *websocket[H]) OnTraffic(c gnet.Conn, packet []byte) {
w.http.OnTraffic(c, packet)
}
func (w *websocket[H]) OnUpgrade(writer netHttp.ResponseWriter, request *netHttp.Request) (err error) {
var (
ip string
conn *ws.Conn
)
ip = request.Header.Get("X-Real-IP")
conn, err = w.upgrader.Upgrade(writer, request, nil)
if err != nil {
return
}
fmt.Println("opened", ip)
go func() {
for {
mt, data, err := conn.ReadMessage()
if err != nil {
continue
}
conn.WriteMessage(mt, data)
}
}()
return nil
}

View File

@@ -0,0 +1,16 @@
package traffickers
import (
"github.com/panjf2000/gnet/v2"
"time"
)
type websocketConn struct {
gnet.Conn
deadline time.Time
}
func (c *websocketConn) SetDeadline(t time.Time) error {
c.deadline = t
return nil
}