console/modules/elastic/api/proxy.go

284 lines
8.1 KiB
Go

package api
import (
"bytes"
"context"
"crypto/tls"
"fmt"
"github.com/buger/jsonparser"
log "github.com/cihub/seelog"
"github.com/segmentio/encoding/json"
"infini.sh/framework/core/api"
httprouter "infini.sh/framework/core/api/router"
"infini.sh/framework/core/elastic"
"infini.sh/framework/core/global"
"infini.sh/framework/core/util"
"infini.sh/framework/lib/fasthttp"
"io"
"net/http"
"net/url"
"strings"
"sync"
"time"
)
var httpPool = fasthttp.NewRequestResponsePool("proxy_search")
func (h *APIHandler) HandleProxyAction(w http.ResponseWriter, req *http.Request, ps httprouter.Params) {
resBody := map[string]interface{}{}
targetClusterID := ps.ByName("id")
method := h.GetParameterOrDefault(req, "method", "")
path := h.GetParameterOrDefault(req, "path", "")
if method == "" || path == "" {
resBody["error"] = fmt.Errorf("parameter method and path is required")
h.WriteJSON(w, resBody, http.StatusInternalServerError)
return
}
exists, esClient, err := h.GetClusterClient(targetClusterID)
if err != nil {
log.Error(err)
resBody["error"] = err.Error()
h.WriteJSON(w, resBody, http.StatusInternalServerError)
return
}
if !exists {
resBody["error"] = fmt.Sprintf("cluster [%s] not found", targetClusterID)
log.Error(resBody["error"])
h.WriteJSON(w, resBody, http.StatusNotFound)
return
}
authPath, _ := url.PathUnescape(path)
var realPath = authPath
newURL, err := url.Parse(realPath)
if err != nil {
log.Error(err)
resBody["error"] = err.Error()
h.WriteJSON(w, resBody, http.StatusInternalServerError)
return
}
if strings.Trim(newURL.Path, "/") == "_sql" {
distribution := esClient.GetVersion().Distribution
indexName, err := rewriteTableNamesOfSqlRequest(req, distribution)
if err != nil {
h.WriteError(w, err.Error(), http.StatusInternalServerError)
return
}
if !h.IsIndexAllowed(req, targetClusterID, indexName) {
h.WriteError(w, fmt.Sprintf("forbidden to access index %s", indexName), http.StatusForbidden)
return
}
q, _ := url.ParseQuery(newURL.RawQuery)
hasFormat := q.Has("format")
switch distribution {
case elastic.Opensearch:
path = "_plugins/_sql?format=raw"
case elastic.Easysearch:
if !hasFormat {
q.Add("format", "raw")
}
path = "_sql?" + q.Encode()
default:
if !hasFormat {
q.Add("format", "txt")
}
path = "_sql?" + q.Encode()
}
}
//ccs search
if parts := strings.SplitN(authPath, "/", 2); strings.Contains(parts[0], ":") {
ccsParts := strings.SplitN(parts[0], ":", 2)
realPath = fmt.Sprintf("%s/%s", ccsParts[1], parts[1])
}
newReq := req.Clone(context.Background())
newReq.URL = newURL
newReq.Method = method
isSuperAdmin, permission, err := h.ValidateProxyRequest(newReq, targetClusterID)
if err != nil {
log.Error(err)
resBody["error"] = err.Error()
h.WriteJSON(w, resBody, http.StatusForbidden)
return
}
if permission == "" && api.IsAuthEnable() && !isSuperAdmin {
resBody["error"] = "unknown request path"
h.WriteJSON(w, resBody, http.StatusForbidden)
return
}
//if permission != "" {
// if permission == "cat.indices" || permission == "cat.shards" {
// reqUrl.Path
// }
//}
var (
freq = httpPool.AcquireRequest()
fres = httpPool.AcquireResponse()
)
defer func() {
httpPool.ReleaseRequest(freq)
httpPool.ReleaseResponse(fres)
}()
metadata := elastic.GetMetadata(targetClusterID)
if metadata == nil {
resBody["error"] = fmt.Sprintf("cluster [%s] metadata not found", targetClusterID)
log.Error(resBody["error"])
h.WriteJSON(w, resBody, http.StatusNotFound)
return
}
if metadata.Config.BasicAuth != nil {
freq.SetBasicAuth(metadata.Config.BasicAuth.Username, metadata.Config.BasicAuth.Password.Get())
}
endpoint := util.JoinPath(metadata.GetActivePreferredSeedEndpoint(), path)
freq.SetRequestURI(endpoint)
method = strings.ToUpper(method)
freq.Header.SetMethod(method)
freq.Header.SetUserAgent(req.Header.Get("user-agent"))
freq.Header.SetReferer(endpoint)
rurl, _ := url.Parse(endpoint)
if rurl != nil {
freq.Header.SetHost(rurl.Host)
freq.Header.SetRequestURI(rurl.RequestURI())
}
clonedURI := freq.CloneURI()
defer fasthttp.ReleaseURI(clonedURI)
clonedURI.SetScheme(metadata.GetSchema())
freq.SetURI(clonedURI)
if permission == "cluster.search" {
indices, hasAll := h.GetAllowedIndices(req, targetClusterID)
if !hasAll && len(indices) == 0 {
h.WriteJSON(w, elastic.SearchResponse{}, http.StatusOK)
return
}
if hasAll {
freq.SetBodyStream(req.Body, int(req.ContentLength))
} else {
body, err := io.ReadAll(req.Body)
if err != nil {
log.Error(err)
h.WriteError(w, err.Error(), http.StatusInternalServerError)
return
}
if len(body) == 0 {
body = []byte("{}")
}
v, _, _, _ := jsonparser.Get(body, "query")
newQ := bytes.NewBuffer([]byte(`{"bool": {"must": [{"terms": {"_index":`))
indicesBytes := util.MustToJSONBytes(indices)
newQ.Write(indicesBytes)
newQ.Write([]byte("}}"))
if len(v) > 0 {
newQ.Write([]byte(","))
newQ.Write(v)
}
newQ.Write([]byte(`]}}`))
body, _ = jsonparser.Set(body, newQ.Bytes(), "query")
freq.SetBody(body)
}
} else {
freq.SetBodyStream(req.Body, int(req.ContentLength))
}
defer req.Body.Close()
err = getHttpClient().Do(freq, fres)
if err != nil {
resBody["error"] = err.Error()
h.WriteJSON(w, resBody, http.StatusInternalServerError)
return
}
okBody := struct {
RequestHeader string `json:"request_header"`
ResponseHeader string `json:"response_header"`
ResponseBody string `json:"response_body"`
}{
RequestHeader: freq.Header.String(),
ResponseHeader: fres.Header.String(),
ResponseBody: string(fres.GetRawBody()),
}
w.Header().Set("Content-type", string(fres.Header.ContentType()))
w.WriteHeader(fres.StatusCode())
json.NewEncoder(w).Encode(okBody)
}
func rewriteTableNamesOfSqlRequest(req *http.Request, distribution string) (string, error) {
var buf bytes.Buffer
if _, err := buf.ReadFrom(req.Body); err != nil {
return "", err
}
if err := req.Body.Close(); err != nil {
return "", err
}
req.Body = io.NopCloser(bytes.NewReader(buf.Bytes()))
sqlQuery, err := jsonparser.GetString(buf.Bytes(), "query")
if err != nil {
return "", fmt.Errorf("parse query from request body error: %w", err)
}
q := util.NewSQLQueryString(sqlQuery)
tableNames, err := q.TableNames()
if err != nil {
return "", err
}
rewriteBody := false
switch distribution {
case elastic.Elasticsearch:
for _, tname := range tableNames {
if strings.ContainsAny(tname, "-.") && !strings.HasPrefix(tname, "\"") {
//append quotes from table name
sqlQuery = strings.Replace(sqlQuery, tname, fmt.Sprintf(`\"%s\"`, tname), -1)
rewriteBody = true
}
}
case elastic.Opensearch, elastic.Easysearch:
for _, tname := range tableNames {
//remove quotes from table name
if strings.HasPrefix(tname, "\"") || strings.HasSuffix(tname, "\"") {
sqlQuery = strings.Replace(sqlQuery, tname, strings.Trim(tname, "\""), -1)
rewriteBody = true
}
}
}
if rewriteBody {
sqlQuery = fmt.Sprintf(`"%s"`, sqlQuery)
reqBody, _ := jsonparser.Set(buf.Bytes(), []byte(sqlQuery), "query")
req.Body = io.NopCloser(bytes.NewReader(reqBody))
req.ContentLength = int64(len(reqBody))
}
var unescapedTableNames []string
for _, tname := range tableNames {
unescapedTableNames = append(unescapedTableNames, strings.Trim(tname, "\""))
}
return strings.Join(unescapedTableNames, ","), nil
}
var (
client *fasthttp.Client
clientOnce sync.Once
)
func getHttpClient() *fasthttp.Client {
clientOnce.Do(func() {
clientCfg := global.Env().SystemConfig.HTTPClientConfig
client = &fasthttp.Client{
MaxConnsPerHost: clientCfg.MaxConnectionPerHost,
TLSConfig: &tls.Config{InsecureSkipVerify: clientCfg.TLSConfig.TLSInsecureSkipVerify},
ReadTimeout: util.GetDurationOrDefault(clientCfg.ReadTimeout, 60*time.Second),
WriteTimeout: util.GetDurationOrDefault(clientCfg.ReadTimeout, 60*time.Second),
DialDualStack: true,
ReadBufferSize: clientCfg.ReadBufferSize,
WriteBufferSize: clientCfg.WriteBufferSize,
//Dial: fasthttpproxy.FasthttpProxyHTTPDialerTimeout(time.Second * 2),
}
})
return client
}