add websocket proxy
This commit is contained in:
parent
62b5b94c56
commit
75b912426b
|
@ -5,8 +5,14 @@
|
|||
package gateway
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"infini.sh/framework/core/api"
|
||||
"infini.sh/framework/core/api/rbac/enum"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
log "github.com/cihub/seelog"
|
||||
"time"
|
||||
)
|
||||
|
||||
type GatewayAPI struct {
|
||||
|
@ -26,4 +32,35 @@ func InitAPI() {
|
|||
api.HandleAPIMethod(api.POST, "/gateway/instance/:instance_id/_proxy", gateway.RequirePermission(gateway.proxy, enum.PermissionGatewayInstanceRead))
|
||||
|
||||
api.HandleAPIMethod(api.GET, "/_platform/nodes", gateway.getExecutionNodes)
|
||||
}
|
||||
api.HandleAPIFunc("/ws_proxy", func(w http.ResponseWriter, req *http.Request) {
|
||||
log.Debug(req.RequestURI)
|
||||
endpoint := req.URL.Query().Get("endpoint")
|
||||
path := req.URL.Query().Get("path")
|
||||
var tlsConfig = &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
target, err := url.Parse(endpoint)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
newURL, err := url.Parse(path)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
req.URL.Path = newURL.Path
|
||||
req.URL.RawPath = newURL.RawPath
|
||||
req.URL.RawQuery = ""
|
||||
req.RequestURI = req.URL.RequestURI()
|
||||
req.Header.Set("HOST", target.Host)
|
||||
req.Host = target.Host
|
||||
wsProxy := NewSingleHostReverseProxy(target)
|
||||
wsProxy.Dial = (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).Dial
|
||||
wsProxy.TLSClientConfig = tlsConfig
|
||||
wsProxy.ServeHTTP(w, req)
|
||||
})
|
||||
}
|
|
@ -0,0 +1,186 @@
|
|||
/* Copyright © INFINI Ltd. All rights reserved.
|
||||
* Web: https://infinilabs.com
|
||||
* Email: hello#infini.ltd */
|
||||
|
||||
package gateway
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"io"
|
||||
log "github.com/cihub/seelog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ReverseProxy is a WebSocket reverse proxy. It will not work with a regular
|
||||
// HTTP request, so it is the caller's responsiblity to ensure the incoming
|
||||
// request is a WebSocket request.
|
||||
type ReverseProxy struct {
|
||||
// Director must be a function which modifies
|
||||
// the request into a new request to be sent
|
||||
// using Transport. Its response is then copied
|
||||
// back to the original client unmodified.
|
||||
Director func(*http.Request)
|
||||
|
||||
// Dial specifies the dial function for dialing the proxied
|
||||
// server over tcp.
|
||||
// If Dial is nil, net.Dial is used.
|
||||
Dial func(network, addr string) (net.Conn, error)
|
||||
|
||||
// TLSClientConfig specifies the TLS configuration to use for 'wss'.
|
||||
// If nil, the default configuration is used.
|
||||
TLSClientConfig *tls.Config
|
||||
|
||||
// ErrorLog specifies an optional logger for errors
|
||||
// that occur when attempting to proxy the request.
|
||||
// If nil, logging goes to os.Stderr via the log package's
|
||||
// standard logger.
|
||||
ErrorLog *log.LoggerInterface
|
||||
}
|
||||
|
||||
// stolen from net/http/httputil. singleJoiningSlash ensures that the route
|
||||
// '/a/' joined with '/b' becomes '/a/b'.
|
||||
func singleJoiningSlash(a, b string) string {
|
||||
aslash := strings.HasSuffix(a, "/")
|
||||
bslash := strings.HasPrefix(b, "/")
|
||||
switch {
|
||||
case aslash && bslash:
|
||||
return a + b[1:]
|
||||
case !aslash && !bslash:
|
||||
return a + "/" + b
|
||||
}
|
||||
return a + b
|
||||
}
|
||||
|
||||
// NewSingleHostReverseProxy returns a new websocket ReverseProxy. The path
|
||||
// rewrites follow the same rules as the httputil.ReverseProxy. If the target
|
||||
// url has the path '/foo' and the incoming request '/bar', the request path
|
||||
// will be updated to '/foo/bar' before forwarding.
|
||||
// Scheme should specify if 'ws' or 'wss' should be used.
|
||||
func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
|
||||
targetQuery := target.RawQuery
|
||||
director := func(req *http.Request) {
|
||||
req.URL.Scheme = target.Scheme
|
||||
req.URL.Host = target.Host
|
||||
req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
|
||||
if targetQuery == "" || req.URL.RawQuery == "" {
|
||||
req.URL.RawQuery = targetQuery + req.URL.RawQuery
|
||||
} else {
|
||||
req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
|
||||
}
|
||||
}
|
||||
return &ReverseProxy{Director: director}
|
||||
}
|
||||
|
||||
// Function to implement the http.Handler interface.
|
||||
func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
logFunc := log.Errorf
|
||||
|
||||
if !IsWebSocketRequest(r) {
|
||||
http.Error(w, "Cannot handle non-WebSocket requests", 500)
|
||||
logFunc("Received a request that was not a WebSocket request")
|
||||
return
|
||||
}
|
||||
|
||||
outreq := new(http.Request)
|
||||
// shallow copying
|
||||
*outreq = *r
|
||||
p.Director(outreq)
|
||||
host := outreq.URL.Host
|
||||
|
||||
if clientIP, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
|
||||
// If we aren't the first proxy retain prior
|
||||
// X-Forwarded-For information as a comma+space
|
||||
// separated list and fold multiple headers into one.
|
||||
if prior, ok := outreq.Header["X-Forwarded-For"]; ok {
|
||||
clientIP = strings.Join(prior, ", ") + ", " + clientIP
|
||||
}
|
||||
outreq.Header.Set("X-Forwarded-For", clientIP)
|
||||
}
|
||||
|
||||
dial := p.Dial
|
||||
if dial == nil {
|
||||
dial = net.Dial
|
||||
}
|
||||
|
||||
// if host does not specify a port, use the default http port
|
||||
if !strings.Contains(host, ":") {
|
||||
if outreq.URL.Scheme == "wss" {
|
||||
host = host + ":443"
|
||||
} else {
|
||||
host = host + ":80"
|
||||
}
|
||||
}
|
||||
|
||||
if outreq.URL.Scheme == "wss" {
|
||||
var tlsConfig *tls.Config
|
||||
if p.TLSClientConfig == nil {
|
||||
tlsConfig = &tls.Config{}
|
||||
} else {
|
||||
tlsConfig = p.TLSClientConfig
|
||||
}
|
||||
dial = func(network, address string) (net.Conn, error) {
|
||||
return tls.Dial("tcp", host, tlsConfig)
|
||||
}
|
||||
}
|
||||
|
||||
d, err := dial("tcp", host)
|
||||
if err != nil {
|
||||
http.Error(w, "Error forwarding request.", 500)
|
||||
logFunc("Error dialing websocket backend %s: %v", outreq.URL, err)
|
||||
return
|
||||
}
|
||||
// All request generated by the http package implement this interface.
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
http.Error(w, "Not a hijacker?", 500)
|
||||
return
|
||||
}
|
||||
// Hijack() tells the http package not to do anything else with the connection.
|
||||
// After, it bcomes this functions job to manage it. `nc` is of type *net.Conn.
|
||||
nc, _, err := hj.Hijack()
|
||||
if err != nil {
|
||||
logFunc("Hijack error: %v", err)
|
||||
return
|
||||
}
|
||||
defer nc.Close() // must close the underlying net connection after hijacking
|
||||
defer d.Close()
|
||||
|
||||
// write the modified incoming request to the dialed connection
|
||||
err = outreq.Write(d)
|
||||
if err != nil {
|
||||
logFunc("Error copying request to target: %v", err)
|
||||
return
|
||||
}
|
||||
errc := make(chan error, 2)
|
||||
cp := func(dst io.Writer, src io.Reader) {
|
||||
_, err := io.Copy(dst, src)
|
||||
errc <- err
|
||||
}
|
||||
go cp(d, nc)
|
||||
go cp(nc, d)
|
||||
<-errc
|
||||
}
|
||||
|
||||
// IsWebSocketRequest returns a boolean indicating whether the request has the
|
||||
// headers of a WebSocket handshake request.
|
||||
func IsWebSocketRequest(r *http.Request) bool {
|
||||
contains := func(key, val string) bool {
|
||||
vv := strings.Split(r.Header.Get(key), ",")
|
||||
for _, v := range vv {
|
||||
if val == strings.ToLower(strings.TrimSpace(v)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
if !contains("Connection", "upgrade") {
|
||||
return false
|
||||
}
|
||||
if !contains("Upgrade", "websocket") {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
Loading…
Reference in New Issue