From 75b912426bc66380d910666fe475b611444b9e25 Mon Sep 17 00:00:00 2001 From: liugq Date: Sat, 6 May 2023 14:53:07 +0800 Subject: [PATCH] add websocket proxy --- plugin/api/gateway/api.go | 39 +++++- plugin/api/gateway/websocket_proxy.go | 186 ++++++++++++++++++++++++++ 2 files changed, 224 insertions(+), 1 deletion(-) create mode 100644 plugin/api/gateway/websocket_proxy.go diff --git a/plugin/api/gateway/api.go b/plugin/api/gateway/api.go index 3034a3a8..6ef82a9e 100644 --- a/plugin/api/gateway/api.go +++ b/plugin/api/gateway/api.go @@ -5,8 +5,14 @@ package gateway import ( + "crypto/tls" "infini.sh/framework/core/api" "infini.sh/framework/core/api/rbac/enum" + "net" + "net/http" + "net/url" + log "github.com/cihub/seelog" + "time" ) type GatewayAPI struct { @@ -26,4 +32,35 @@ func InitAPI() { api.HandleAPIMethod(api.POST, "/gateway/instance/:instance_id/_proxy", gateway.RequirePermission(gateway.proxy, enum.PermissionGatewayInstanceRead)) api.HandleAPIMethod(api.GET, "/_platform/nodes", gateway.getExecutionNodes) -} + api.HandleAPIFunc("/ws_proxy", func(w http.ResponseWriter, req *http.Request) { + log.Debug(req.RequestURI) + endpoint := req.URL.Query().Get("endpoint") + path := req.URL.Query().Get("path") + var tlsConfig = &tls.Config{ + InsecureSkipVerify: true, + } + target, err := url.Parse(endpoint) + if err != nil { + log.Error(err) + return + } + newURL, err := url.Parse(path) + if err != nil { + log.Error(err) + return + } + req.URL.Path = newURL.Path + req.URL.RawPath = newURL.RawPath + req.URL.RawQuery = "" + req.RequestURI = req.URL.RequestURI() + req.Header.Set("HOST", target.Host) + req.Host = target.Host + wsProxy := NewSingleHostReverseProxy(target) + wsProxy.Dial = (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).Dial + wsProxy.TLSClientConfig = tlsConfig + wsProxy.ServeHTTP(w, req) + }) +} \ No newline at end of file diff --git a/plugin/api/gateway/websocket_proxy.go b/plugin/api/gateway/websocket_proxy.go new file mode 100644 index 00000000..99fe929e --- /dev/null +++ b/plugin/api/gateway/websocket_proxy.go @@ -0,0 +1,186 @@ +/* Copyright © INFINI Ltd. All rights reserved. + * Web: https://infinilabs.com + * Email: hello#infini.ltd */ + +package gateway + +import ( + "crypto/tls" + "io" + log "github.com/cihub/seelog" + "net" + "net/http" + "net/url" + "strings" +) + +// ReverseProxy is a WebSocket reverse proxy. It will not work with a regular +// HTTP request, so it is the caller's responsiblity to ensure the incoming +// request is a WebSocket request. +type ReverseProxy struct { + // Director must be a function which modifies + // the request into a new request to be sent + // using Transport. Its response is then copied + // back to the original client unmodified. + Director func(*http.Request) + + // Dial specifies the dial function for dialing the proxied + // server over tcp. + // If Dial is nil, net.Dial is used. + Dial func(network, addr string) (net.Conn, error) + + // TLSClientConfig specifies the TLS configuration to use for 'wss'. + // If nil, the default configuration is used. + TLSClientConfig *tls.Config + + // ErrorLog specifies an optional logger for errors + // that occur when attempting to proxy the request. + // If nil, logging goes to os.Stderr via the log package's + // standard logger. + ErrorLog *log.LoggerInterface +} + +// stolen from net/http/httputil. singleJoiningSlash ensures that the route +// '/a/' joined with '/b' becomes '/a/b'. +func singleJoiningSlash(a, b string) string { + aslash := strings.HasSuffix(a, "/") + bslash := strings.HasPrefix(b, "/") + switch { + case aslash && bslash: + return a + b[1:] + case !aslash && !bslash: + return a + "/" + b + } + return a + b +} + +// NewSingleHostReverseProxy returns a new websocket ReverseProxy. The path +// rewrites follow the same rules as the httputil.ReverseProxy. If the target +// url has the path '/foo' and the incoming request '/bar', the request path +// will be updated to '/foo/bar' before forwarding. +// Scheme should specify if 'ws' or 'wss' should be used. +func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy { + targetQuery := target.RawQuery + director := func(req *http.Request) { + req.URL.Scheme = target.Scheme + req.URL.Host = target.Host + req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) + if targetQuery == "" || req.URL.RawQuery == "" { + req.URL.RawQuery = targetQuery + req.URL.RawQuery + } else { + req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery + } + } + return &ReverseProxy{Director: director} +} + +// Function to implement the http.Handler interface. +func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { + logFunc := log.Errorf + + if !IsWebSocketRequest(r) { + http.Error(w, "Cannot handle non-WebSocket requests", 500) + logFunc("Received a request that was not a WebSocket request") + return + } + + outreq := new(http.Request) + // shallow copying + *outreq = *r + p.Director(outreq) + host := outreq.URL.Host + + if clientIP, _, err := net.SplitHostPort(r.RemoteAddr); err == nil { + // If we aren't the first proxy retain prior + // X-Forwarded-For information as a comma+space + // separated list and fold multiple headers into one. + if prior, ok := outreq.Header["X-Forwarded-For"]; ok { + clientIP = strings.Join(prior, ", ") + ", " + clientIP + } + outreq.Header.Set("X-Forwarded-For", clientIP) + } + + dial := p.Dial + if dial == nil { + dial = net.Dial + } + + // if host does not specify a port, use the default http port + if !strings.Contains(host, ":") { + if outreq.URL.Scheme == "wss" { + host = host + ":443" + } else { + host = host + ":80" + } + } + + if outreq.URL.Scheme == "wss" { + var tlsConfig *tls.Config + if p.TLSClientConfig == nil { + tlsConfig = &tls.Config{} + } else { + tlsConfig = p.TLSClientConfig + } + dial = func(network, address string) (net.Conn, error) { + return tls.Dial("tcp", host, tlsConfig) + } + } + + d, err := dial("tcp", host) + if err != nil { + http.Error(w, "Error forwarding request.", 500) + logFunc("Error dialing websocket backend %s: %v", outreq.URL, err) + return + } + // All request generated by the http package implement this interface. + hj, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "Not a hijacker?", 500) + return + } + // Hijack() tells the http package not to do anything else with the connection. + // After, it bcomes this functions job to manage it. `nc` is of type *net.Conn. + nc, _, err := hj.Hijack() + if err != nil { + logFunc("Hijack error: %v", err) + return + } + defer nc.Close() // must close the underlying net connection after hijacking + defer d.Close() + + // write the modified incoming request to the dialed connection + err = outreq.Write(d) + if err != nil { + logFunc("Error copying request to target: %v", err) + return + } + errc := make(chan error, 2) + cp := func(dst io.Writer, src io.Reader) { + _, err := io.Copy(dst, src) + errc <- err + } + go cp(d, nc) + go cp(nc, d) + <-errc +} + +// IsWebSocketRequest returns a boolean indicating whether the request has the +// headers of a WebSocket handshake request. +func IsWebSocketRequest(r *http.Request) bool { + contains := func(key, val string) bool { + vv := strings.Split(r.Header.Get(key), ",") + for _, v := range vv { + if val == strings.ToLower(strings.TrimSpace(v)) { + return true + } + } + return false + } + if !contains("Connection", "upgrade") { + return false + } + if !contains("Upgrade", "websocket") { + return false + } + return true +}