other: 新 server 包调整
This commit is contained in:
@@ -1,18 +1,18 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/panjf2000/ants/v2"
|
||||
"github.com/panjf2000/gnet/v2"
|
||||
"time"
|
||||
)
|
||||
|
||||
func newEventHandler(trafficker Trafficker) (handler *eventHandler, err error) {
|
||||
func newEventHandler(options *Options, trafficker Trafficker) (handler *eventHandler, err error) {
|
||||
var wp *ants.Pool
|
||||
if wp, err = ants.NewPool(ants.DefaultAntsPoolSize, ants.WithNonblocking(true)); err != nil {
|
||||
return
|
||||
}
|
||||
handler = &eventHandler{
|
||||
options: options,
|
||||
trafficker: trafficker,
|
||||
workerPool: wp,
|
||||
}
|
||||
@@ -21,17 +21,18 @@ func newEventHandler(trafficker Trafficker) (handler *eventHandler, err error) {
|
||||
|
||||
type (
|
||||
Trafficker interface {
|
||||
OnBoot() error
|
||||
OnBoot(options *Options) error
|
||||
OnTraffic(c gnet.Conn, packet []byte)
|
||||
}
|
||||
eventHandler struct {
|
||||
options *Options
|
||||
trafficker Trafficker
|
||||
workerPool *ants.Pool
|
||||
}
|
||||
)
|
||||
|
||||
func (e *eventHandler) OnBoot(eng gnet.Engine) (action gnet.Action) {
|
||||
if err := e.trafficker.OnBoot(); err != nil {
|
||||
if err := e.trafficker.OnBoot(e.options); err != nil {
|
||||
action = gnet.Shutdown
|
||||
}
|
||||
return
|
||||
@@ -46,7 +47,6 @@ func (e *eventHandler) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
|
||||
}
|
||||
|
||||
func (e *eventHandler) OnClose(c gnet.Conn, err error) (action gnet.Action) {
|
||||
fmt.Println("断开")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
4
server/v2/options.go
Normal file
4
server/v2/options.go
Normal file
@@ -0,0 +1,4 @@
|
||||
package server
|
||||
|
||||
type Options struct {
|
||||
}
|
||||
@@ -15,6 +15,6 @@ type Server struct {
|
||||
|
||||
func (s *Server) Run(protoAddr string) (err error) {
|
||||
var handler *eventHandler
|
||||
handler, err = newEventHandler(s.trafficker)
|
||||
handler, err = newEventHandler(new(Options), s.trafficker)
|
||||
return gnet.Run(handler, protoAddr)
|
||||
}
|
||||
|
||||
29
server/v2/server_test.go
Normal file
29
server/v2/server_test.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package server_test
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/kercylan98/minotaur/server/v2"
|
||||
"github.com/kercylan98/minotaur/server/v2/traffickers"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewServer(t *testing.T) {
|
||||
r := gin.New()
|
||||
r.GET("/", func(context *gin.Context) {
|
||||
context.JSON(200, gin.H{
|
||||
"ping": "pong",
|
||||
})
|
||||
})
|
||||
srv := server.NewServer(traffickers.WebSocket(r, func(handler *gin.Engine, upgradeHandler func(writer http.ResponseWriter, request *http.Request) error) {
|
||||
handler.GET("/ws", func(context *gin.Context) {
|
||||
if err := upgradeHandler(context.Writer, context.Request); err != nil {
|
||||
context.AbortWithError(500, err)
|
||||
}
|
||||
})
|
||||
}))
|
||||
|
||||
if err := srv.Run("tcp://:8080"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
@@ -9,18 +9,22 @@ import (
|
||||
netHttp "net/http"
|
||||
)
|
||||
|
||||
func Http(handler netHttp.Handler) server.Trafficker {
|
||||
return &http{
|
||||
func Http[H netHttp.Handler](handler H) server.Trafficker {
|
||||
return &http[H]{
|
||||
handler: handler,
|
||||
ncb: func(c gnet.Conn, err error) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type http struct {
|
||||
handler netHttp.Handler
|
||||
type http[H netHttp.Handler] struct {
|
||||
handler H
|
||||
rwp *hub.ObjectPool[*httpResponseWriter]
|
||||
ncb func(c gnet.Conn, err error) error
|
||||
}
|
||||
|
||||
func (h *http) OnBoot() error {
|
||||
func (h *http[H]) OnBoot(options *server.Options) error {
|
||||
h.rwp = hub.NewObjectPool[httpResponseWriter](func() *httpResponseWriter {
|
||||
return new(httpResponseWriter)
|
||||
}, func(data *httpResponseWriter) {
|
||||
@@ -29,14 +33,24 @@ func (h *http) OnBoot() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *http) OnTraffic(c gnet.Conn, packet []byte) {
|
||||
func (h *http[H]) OnTraffic(c gnet.Conn, packet []byte) {
|
||||
var responseWriter *httpResponseWriter
|
||||
defer func() {
|
||||
if responseWriter == nil || !responseWriter.isHijack {
|
||||
_ = c.Close()
|
||||
}
|
||||
}()
|
||||
httpRequest, err := netHttp.ReadRequest(bufio.NewReader(bytes.NewReader(packet)))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
responseWriter := h.rwp.Get()
|
||||
responseWriter = h.rwp.Get()
|
||||
responseWriter.init(c)
|
||||
|
||||
h.handler.ServeHTTP(responseWriter, httpRequest)
|
||||
if responseWriter.isHijack {
|
||||
return
|
||||
}
|
||||
_ = responseWriter.Result().Write(c)
|
||||
}
|
||||
|
||||
211
server/v2/traffickers/http_recorder.go
Normal file
211
server/v2/traffickers/http_recorder.go
Normal file
@@ -0,0 +1,211 @@
|
||||
package traffickers
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"github.com/panjf2000/gnet/v2"
|
||||
"io"
|
||||
"net"
|
||||
netHttp "net/http"
|
||||
"net/textproto"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/net/http/httpguts"
|
||||
)
|
||||
|
||||
type httpResponseWriter struct {
|
||||
Code int
|
||||
HeaderMap netHttp.Header
|
||||
Body *bytes.Buffer
|
||||
Flushed bool
|
||||
|
||||
conn *websocketConn
|
||||
result *netHttp.Response
|
||||
snapHeader netHttp.Header
|
||||
wroteHeader bool
|
||||
isHijack bool
|
||||
}
|
||||
|
||||
func (rw *httpResponseWriter) init(c gnet.Conn) {
|
||||
rw.conn = &websocketConn{Conn: c}
|
||||
rw.Code = 200
|
||||
rw.Body = new(bytes.Buffer)
|
||||
rw.HeaderMap = make(netHttp.Header)
|
||||
rw.isHijack = false
|
||||
}
|
||||
|
||||
func (rw *httpResponseWriter) reset() {
|
||||
rw.conn = nil
|
||||
rw.Code = 200
|
||||
rw.Body = nil
|
||||
rw.HeaderMap = nil
|
||||
rw.isHijack = false
|
||||
}
|
||||
|
||||
func (rw *httpResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if !rw.isHijack {
|
||||
return rw.conn, bufio.NewReadWriter(bufio.NewReader(rw.conn), bufio.NewWriter(rw.conn)), nil
|
||||
}
|
||||
return nil, nil, netHttp.ErrHijacked
|
||||
}
|
||||
|
||||
func (rw *httpResponseWriter) Header() netHttp.Header {
|
||||
m := rw.HeaderMap
|
||||
if m == nil {
|
||||
m = make(netHttp.Header)
|
||||
rw.HeaderMap = m
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func (rw *httpResponseWriter) writeHeader(b []byte, str string) {
|
||||
if rw.wroteHeader {
|
||||
return
|
||||
}
|
||||
if len(str) > 512 {
|
||||
str = str[:512]
|
||||
}
|
||||
|
||||
m := rw.Header()
|
||||
|
||||
_, hasType := m["Content-Type"]
|
||||
hasTE := m.Get("Transfer-Encoding") != ""
|
||||
if !hasType && !hasTE {
|
||||
if b == nil {
|
||||
b = []byte(str)
|
||||
}
|
||||
m.Set("Content-Type", netHttp.DetectContentType(b))
|
||||
}
|
||||
|
||||
rw.WriteHeader(200)
|
||||
}
|
||||
|
||||
func (rw *httpResponseWriter) Write(buf []byte) (n int, err error) {
|
||||
if rw.isHijack {
|
||||
n = len(buf)
|
||||
var wait = make(chan error)
|
||||
if err = rw.conn.AsyncWrite(buf, func(c gnet.Conn, err error) error {
|
||||
if err != nil {
|
||||
wait <- err
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
err = <-wait
|
||||
return
|
||||
}
|
||||
rw.writeHeader(buf, "")
|
||||
if rw.Body != nil {
|
||||
rw.Body.Write(buf)
|
||||
}
|
||||
return len(buf), nil
|
||||
}
|
||||
|
||||
func (rw *httpResponseWriter) WriteString(str string) (int, error) {
|
||||
rw.writeHeader(nil, str)
|
||||
if rw.Body != nil {
|
||||
rw.Body.WriteString(str)
|
||||
}
|
||||
return len(str), nil
|
||||
}
|
||||
|
||||
func checkWriteHeaderCode(code int) {
|
||||
if code < 100 || code > 999 {
|
||||
panic(fmt.Sprintf("invalid WriteHeader code %v", code))
|
||||
}
|
||||
}
|
||||
|
||||
func (rw *httpResponseWriter) WriteHeader(code int) {
|
||||
if rw.wroteHeader {
|
||||
return
|
||||
}
|
||||
|
||||
checkWriteHeaderCode(code)
|
||||
rw.Code = code
|
||||
rw.wroteHeader = true
|
||||
if rw.HeaderMap == nil {
|
||||
rw.HeaderMap = make(netHttp.Header)
|
||||
}
|
||||
rw.snapHeader = rw.HeaderMap.Clone()
|
||||
}
|
||||
|
||||
func (rw *httpResponseWriter) Flush() {
|
||||
if !rw.wroteHeader {
|
||||
rw.WriteHeader(200)
|
||||
}
|
||||
rw.Flushed = true
|
||||
}
|
||||
|
||||
func (rw *httpResponseWriter) Result() *netHttp.Response {
|
||||
if rw.result != nil {
|
||||
return rw.result
|
||||
}
|
||||
if rw.snapHeader == nil {
|
||||
rw.snapHeader = rw.HeaderMap.Clone()
|
||||
}
|
||||
res := &netHttp.Response{
|
||||
Proto: "HTTP/1.1",
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
StatusCode: rw.Code,
|
||||
Header: rw.snapHeader,
|
||||
}
|
||||
rw.result = res
|
||||
if res.StatusCode == 0 {
|
||||
res.StatusCode = 200
|
||||
}
|
||||
res.Status = fmt.Sprintf("%03d %s", res.StatusCode, netHttp.StatusText(res.StatusCode))
|
||||
if rw.Body != nil {
|
||||
res.Body = io.NopCloser(bytes.NewReader(rw.Body.Bytes()))
|
||||
} else {
|
||||
res.Body = netHttp.NoBody
|
||||
}
|
||||
res.ContentLength = parseContentLength(res.Header.Get("Content-Length"))
|
||||
|
||||
if trailers, ok := rw.snapHeader["Trailer"]; ok {
|
||||
res.Trailer = make(netHttp.Header, len(trailers))
|
||||
for _, k := range trailers {
|
||||
for _, k := range strings.Split(k, ",") {
|
||||
k = netHttp.CanonicalHeaderKey(textproto.TrimString(k))
|
||||
if !httpguts.ValidTrailerHeader(k) {
|
||||
// Ignore since forbidden by RFC 7230, section 4.1.2.
|
||||
continue
|
||||
}
|
||||
vv, ok := rw.HeaderMap[k]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
vv2 := make([]string, len(vv))
|
||||
copy(vv2, vv)
|
||||
res.Trailer[k] = vv2
|
||||
}
|
||||
}
|
||||
}
|
||||
for k, vv := range rw.HeaderMap {
|
||||
if !strings.HasPrefix(k, netHttp.TrailerPrefix) {
|
||||
continue
|
||||
}
|
||||
if res.Trailer == nil {
|
||||
res.Trailer = make(netHttp.Header)
|
||||
}
|
||||
for _, v := range vv {
|
||||
res.Trailer.Add(strings.TrimPrefix(k, netHttp.TrailerPrefix), v)
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func parseContentLength(cl string) int64 {
|
||||
cl = textproto.TrimString(cl)
|
||||
if cl == "" {
|
||||
return -1
|
||||
}
|
||||
n, err := strconv.ParseUint(cl, 10, 63)
|
||||
if err != nil {
|
||||
return -1
|
||||
}
|
||||
return int64(n)
|
||||
}
|
||||
@@ -1,58 +0,0 @@
|
||||
package traffickers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/panjf2000/gnet/v2"
|
||||
netHttp "net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type httpResponseWriter struct {
|
||||
c gnet.Conn
|
||||
statusCode int
|
||||
header netHttp.Header
|
||||
}
|
||||
|
||||
func (w *httpResponseWriter) init(c gnet.Conn) {
|
||||
w.c = c
|
||||
w.statusCode = 200
|
||||
w.header = make(netHttp.Header)
|
||||
}
|
||||
|
||||
func (w *httpResponseWriter) reset() {
|
||||
w.c = nil
|
||||
w.statusCode = 200
|
||||
w.header = nil
|
||||
}
|
||||
|
||||
func (w *httpResponseWriter) Header() netHttp.Header {
|
||||
return w.header
|
||||
}
|
||||
|
||||
func (w *httpResponseWriter) Write(b []byte) (n int, err error) {
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString("HTTP/1.1 ")
|
||||
buf.WriteString(netHttp.StatusText(w.statusCode))
|
||||
buf.WriteString("\r\n")
|
||||
w.header.Set("Content-Length", strconv.Itoa(len(b)))
|
||||
if err = w.header.Write(&buf); err != nil {
|
||||
return
|
||||
}
|
||||
buf.WriteString("\r\n")
|
||||
buf.Write(b)
|
||||
res := buf.Bytes()
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
err = w.c.AsyncWrite(res, func(c gnet.Conn, e error) error {
|
||||
err = e
|
||||
wg.Done()
|
||||
return nil
|
||||
})
|
||||
wg.Wait()
|
||||
return len(res), err
|
||||
}
|
||||
|
||||
func (w *httpResponseWriter) WriteHeader(statusCode int) {
|
||||
w.statusCode = statusCode
|
||||
}
|
||||
64
server/v2/traffickers/websocket.go
Normal file
64
server/v2/traffickers/websocket.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package traffickers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
ws "github.com/gorilla/websocket"
|
||||
"github.com/kercylan98/minotaur/server/v2"
|
||||
"github.com/panjf2000/gnet/v2"
|
||||
netHttp "net/http"
|
||||
)
|
||||
|
||||
func WebSocket[H netHttp.Handler](handler H, binder func(handler H, upgradeHandler func(writer netHttp.ResponseWriter, request *netHttp.Request) error)) server.Trafficker {
|
||||
w := &websocket[H]{
|
||||
http: Http(handler).(*http[H]),
|
||||
binder: binder,
|
||||
upgrader: &ws.Upgrader{
|
||||
ReadBufferSize: 4096,
|
||||
WriteBufferSize: 4096,
|
||||
CheckOrigin: func(r *netHttp.Request) bool {
|
||||
return true
|
||||
},
|
||||
},
|
||||
}
|
||||
binder(handler, w.OnUpgrade)
|
||||
return w
|
||||
}
|
||||
|
||||
type websocket[H netHttp.Handler] struct {
|
||||
*http[H]
|
||||
binder func(handler H, upgradeHandler func(writer netHttp.ResponseWriter, request *netHttp.Request) error)
|
||||
upgrader *ws.Upgrader
|
||||
}
|
||||
|
||||
func (w *websocket[H]) OnBoot(options *server.Options) error {
|
||||
return w.http.OnBoot(options)
|
||||
}
|
||||
|
||||
func (w *websocket[H]) OnTraffic(c gnet.Conn, packet []byte) {
|
||||
w.http.OnTraffic(c, packet)
|
||||
}
|
||||
|
||||
func (w *websocket[H]) OnUpgrade(writer netHttp.ResponseWriter, request *netHttp.Request) (err error) {
|
||||
var (
|
||||
ip string
|
||||
conn *ws.Conn
|
||||
)
|
||||
|
||||
ip = request.Header.Get("X-Real-IP")
|
||||
conn, err = w.upgrader.Upgrade(writer, request, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("opened", ip)
|
||||
go func() {
|
||||
for {
|
||||
mt, data, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
conn.WriteMessage(mt, data)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
16
server/v2/traffickers/websocket_conn.go
Normal file
16
server/v2/traffickers/websocket_conn.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package traffickers
|
||||
|
||||
import (
|
||||
"github.com/panjf2000/gnet/v2"
|
||||
"time"
|
||||
)
|
||||
|
||||
type websocketConn struct {
|
||||
gnet.Conn
|
||||
deadline time.Time
|
||||
}
|
||||
|
||||
func (c *websocketConn) SetDeadline(t time.Time) error {
|
||||
c.deadline = t
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user