187 lines
		
	
	
		
			5.2 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			187 lines
		
	
	
		
			5.2 KiB
		
	
	
	
		
			Go
		
	
	
	
| /* Copyright © INFINI Ltd. All rights reserved.
 | |
|  * Web: https://infinilabs.com
 | |
|  * Email: hello#infini.ltd */
 | |
| 
 | |
| package gateway
 | |
| 
 | |
| import (
 | |
| 	"crypto/tls"
 | |
| 	"io"
 | |
| 	log "github.com/cihub/seelog"
 | |
| 	"net"
 | |
| 	"net/http"
 | |
| 	"net/url"
 | |
| 	"strings"
 | |
| )
 | |
| 
 | |
| // ReverseProxy is a WebSocket reverse proxy. It will not work with a regular
 | |
| // HTTP request, so it is the caller's responsiblity to ensure the incoming
 | |
| // request is a WebSocket request.
 | |
| type ReverseProxy struct {
 | |
| 	// Director must be a function which modifies
 | |
| 	// the request into a new request to be sent
 | |
| 	// using Transport. Its response is then copied
 | |
| 	// back to the original client unmodified.
 | |
| 	Director func(*http.Request)
 | |
| 
 | |
| 	// Dial specifies the dial function for dialing the proxied
 | |
| 	// server over tcp.
 | |
| 	// If Dial is nil, net.Dial is used.
 | |
| 	Dial func(network, addr string) (net.Conn, error)
 | |
| 
 | |
| 	// TLSClientConfig specifies the TLS configuration to use for 'wss'.
 | |
| 	// If nil, the default configuration is used.
 | |
| 	TLSClientConfig *tls.Config
 | |
| 
 | |
| 	// ErrorLog specifies an optional logger for errors
 | |
| 	// that occur when attempting to proxy the request.
 | |
| 	// If nil, logging goes to os.Stderr via the log package's
 | |
| 	// standard logger.
 | |
| 	ErrorLog *log.LoggerInterface
 | |
| }
 | |
| 
 | |
// singleJoiningSlash concatenates a and b with exactly one "/" at the seam,
// so "/a/" + "/b" and "/a" + "b" both produce "/a/b". Adapted from
// net/http/httputil.
func singleJoiningSlash(a, b string) string {
	endsWithSlash := strings.HasSuffix(a, "/")
	startsWithSlash := strings.HasPrefix(b, "/")
	if endsWithSlash && startsWithSlash {
		// Both sides supply a slash: drop the one leading b.
		return a + strings.TrimPrefix(b, "/")
	}
	if !endsWithSlash && !startsWithSlash {
		// Neither side supplies a slash: insert one.
		return a + "/" + b
	}
	// Exactly one side already supplies the slash.
	return a + b
}
 | |
| 
 | |
| // NewSingleHostReverseProxy returns a new websocket ReverseProxy. The path
 | |
| // rewrites follow the same rules as the httputil.ReverseProxy. If the target
 | |
| // url has the path '/foo' and the incoming request '/bar', the request path
 | |
| // will be updated to '/foo/bar' before forwarding.
 | |
| // Scheme should specify if 'ws' or 'wss' should be used.
 | |
| func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
 | |
| 	targetQuery := target.RawQuery
 | |
| 	director := func(req *http.Request) {
 | |
| 		req.URL.Scheme = target.Scheme
 | |
| 		req.URL.Host = target.Host
 | |
| 		req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
 | |
| 		if targetQuery == "" || req.URL.RawQuery == "" {
 | |
| 			req.URL.RawQuery = targetQuery + req.URL.RawQuery
 | |
| 		} else {
 | |
| 			req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
 | |
| 		}
 | |
| 	}
 | |
| 	return &ReverseProxy{Director: director}
 | |
| }
 | |
| 
 | |
| // Function to implement the http.Handler interface.
 | |
| func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 | |
| 	logFunc := log.Errorf
 | |
| 
 | |
| 	if !IsWebSocketRequest(r) {
 | |
| 		http.Error(w, "Cannot handle non-WebSocket requests", 500)
 | |
| 		logFunc("Received a request that was not a WebSocket request")
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	outreq := new(http.Request)
 | |
| 	// shallow copying
 | |
| 	*outreq = *r
 | |
| 	p.Director(outreq)
 | |
| 	host := outreq.URL.Host
 | |
| 
 | |
| 	if clientIP, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
 | |
| 		// If we aren't the first proxy retain prior
 | |
| 		// X-Forwarded-For information as a comma+space
 | |
| 		// separated list and fold multiple headers into one.
 | |
| 		if prior, ok := outreq.Header["X-Forwarded-For"]; ok {
 | |
| 			clientIP = strings.Join(prior, ", ") + ", " + clientIP
 | |
| 		}
 | |
| 		outreq.Header.Set("X-Forwarded-For", clientIP)
 | |
| 	}
 | |
| 
 | |
| 	dial := p.Dial
 | |
| 	if dial == nil {
 | |
| 		dial = net.Dial
 | |
| 	}
 | |
| 
 | |
| 	// if host does not specify a port, use the default http port
 | |
| 	if !strings.Contains(host, ":") {
 | |
| 		if outreq.URL.Scheme == "wss" {
 | |
| 			host = host + ":443"
 | |
| 		} else {
 | |
| 			host = host + ":80"
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if outreq.URL.Scheme == "wss" {
 | |
| 		var tlsConfig *tls.Config
 | |
| 		if p.TLSClientConfig == nil {
 | |
| 			tlsConfig = &tls.Config{}
 | |
| 		} else {
 | |
| 			tlsConfig = p.TLSClientConfig
 | |
| 		}
 | |
| 		dial = func(network, address string) (net.Conn, error) {
 | |
| 			return tls.Dial("tcp", host, tlsConfig)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	d, err := dial("tcp", host)
 | |
| 	if err != nil {
 | |
| 		http.Error(w, "Error forwarding request.", 500)
 | |
| 		logFunc("Error dialing websocket backend %s: %v", outreq.URL, err)
 | |
| 		return
 | |
| 	}
 | |
| 	// All request generated by the http package implement this interface.
 | |
| 	hj, ok := w.(http.Hijacker)
 | |
| 	if !ok {
 | |
| 		http.Error(w, "Not a hijacker?", 500)
 | |
| 		return
 | |
| 	}
 | |
| 	// Hijack() tells the http package not to do anything else with the connection.
 | |
| 	// After, it bcomes this functions job to manage it. `nc` is of type *net.Conn.
 | |
| 	nc, _, err := hj.Hijack()
 | |
| 	if err != nil {
 | |
| 		logFunc("Hijack error: %v", err)
 | |
| 		return
 | |
| 	}
 | |
| 	defer nc.Close() // must close the underlying net connection after hijacking
 | |
| 	defer d.Close()
 | |
| 
 | |
| 	// write the modified incoming request to the dialed connection
 | |
| 	err = outreq.Write(d)
 | |
| 	if err != nil {
 | |
| 		logFunc("Error copying request to target: %v", err)
 | |
| 		return
 | |
| 	}
 | |
| 	errc := make(chan error, 2)
 | |
| 	cp := func(dst io.Writer, src io.Reader) {
 | |
| 		_, err := io.Copy(dst, src)
 | |
| 		errc <- err
 | |
| 	}
 | |
| 	go cp(d, nc)
 | |
| 	go cp(nc, d)
 | |
| 	<-errc
 | |
| }
 | |
| 
 | |
// IsWebSocketRequest returns a boolean indicating whether the request has the
// headers of a WebSocket handshake request: a Connection header listing the
// "upgrade" token and an Upgrade header listing the "websocket" token.
// Tokens are matched case-insensitively. Every occurrence of each header is
// examined, because a client may split a comma-separated list across several
// header lines (e.g. "Connection: keep-alive" and "Connection: Upgrade").
func IsWebSocketRequest(r *http.Request) bool {
	hasToken := func(key, token string) bool {
		for _, line := range r.Header[http.CanonicalHeaderKey(key)] {
			for _, v := range strings.Split(line, ",") {
				if strings.EqualFold(strings.TrimSpace(v), token) {
					return true
				}
			}
		}
		return false
	}
	return hasToken("Connection", "upgrade") && hasToken("Upgrade", "websocket")
}
 |