// Source file: console/modules/elastic/api/proxy.go (Go, 291 lines, 8.5 KiB)

// Copyright (C) INFINI Labs & INFINI LIMITED.
//
// The INFINI Console is offered under the GNU Affero General Public License v3.0
// and as commercial software.
//
// For commercial licensing, contact us at:
// - Website: infinilabs.com
// - Email: hello@infini.ltd
//
// Open Source licensed under AGPL V3:
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package api
import (
"bytes"
"context"
"fmt"
"github.com/buger/jsonparser"
log "github.com/cihub/seelog"
"github.com/segmentio/encoding/json"
"infini.sh/framework/core/api"
httprouter "infini.sh/framework/core/api/router"
"infini.sh/framework/core/elastic"
"infini.sh/framework/core/util"
"infini.sh/framework/lib/fasthttp"
"io"
"net/http"
"net/url"
"strings"
)
// httpPool supplies the reusable fasthttp request/response objects that
// HandleProxyAction acquires for each proxied call and releases when it returns.
var httpPool = fasthttp.NewRequestResponsePool("proxy_search")
// HandleProxyAction forwards a request to the cluster named by the "id" route
// parameter and writes the upstream exchange back as a JSON envelope with the
// fields request_header, response_header and response_body.
//
// The upstream HTTP method and path come from the "method" and "path" request
// parameters; the incoming request body becomes the upstream body. Requests to
// the "_sql" path get their table names rewritten and their target path adapted
// to the cluster distribution. When the resolved permission is
// "cluster.search" and the caller is limited to specific indices, the search
// body's query is wrapped in a bool/must clause that filters on "_index".
//
// Failures are reported as JSON error payloads with a 4xx/5xx status code.
func (h *APIHandler) HandleProxyAction(w http.ResponseWriter, req *http.Request, ps httprouter.Params) {
	resBody := map[string]interface{}{}
	targetClusterID := ps.ByName("id")
	method := h.GetParameterOrDefault(req, "method", "")
	path := h.GetParameterOrDefault(req, "path", "")
	if method == "" || path == "" {
		// Store the message as a string: an error value created by fmt.Errorf
		// has no exported fields, so JSON encoders typically emit "{}" for it.
		resBody["error"] = "parameter method and path is required"
		h.WriteJSON(w, resBody, http.StatusInternalServerError)
		return
	}

	exists, esClient, err := h.GetClusterClient(targetClusterID)
	if err != nil {
		log.Error(err)
		resBody["error"] = err.Error()
		h.WriteJSON(w, resBody, http.StatusInternalServerError)
		return
	}
	if !exists {
		resBody["error"] = fmt.Sprintf("cluster [%s] not found", targetClusterID)
		log.Error(resBody["error"])
		h.WriteJSON(w, resBody, http.StatusNotFound)
		return
	}

	// NOTE(review): the unescape error is ignored; url.PathUnescape returns an
	// empty string on failure, so validation then runs against an empty path.
	authPath, _ := url.PathUnescape(path)
	var realPath = authPath
	newURL, err := url.Parse(realPath)
	if err != nil {
		log.Error(err)
		resBody["error"] = err.Error()
		h.WriteJSON(w, resBody, http.StatusInternalServerError)
		return
	}

	if strings.Trim(newURL.Path, "/") == "_sql" {
		distribution := esClient.GetVersion().Distribution
		version := esClient.GetVersion().Number
		indexName, err := rewriteTableNamesOfSqlRequest(req, distribution)
		if err != nil {
			h.WriteError(w, err.Error(), http.StatusInternalServerError)
			return
		}
		if !h.IsIndexAllowed(req, targetClusterID, indexName) {
			h.WriteError(w, fmt.Sprintf("forbidden to access index %s", indexName), http.StatusForbidden)
			return
		}
		// Parse errors are ignored here; whatever pairs could be parsed are kept.
		q, _ := url.ParseQuery(newURL.RawQuery)
		hasFormat := q.Has("format")
		switch distribution {
		case elastic.Elasticsearch:
			if !hasFormat {
				q.Add("format", "txt")
			}
			// The endpoint depends on how the version compares to 7.0.0.
			if large, _ := util.VersionCompare(version, "7.0.0"); large > 0 {
				path = "_sql?" + q.Encode()
			} else {
				path = "_xpack/_sql?" + q.Encode()
			}
		case elastic.Opensearch:
			// Caller-supplied query parameters are not forwarded for this distribution.
			path = "_plugins/_sql?format=raw"
		case elastic.Easysearch:
			if !hasFormat {
				q.Add("format", "raw")
			}
			path = "_sql?" + q.Encode()
		default:
			if !hasFormat {
				q.Add("format", "txt")
			}
			path = "_sql?" + q.Encode()
		}
	}

	//ccs search
	// Only strip the "<remote>:" prefix when the path really has a second
	// segment; indexing parts[1] unconditionally panics for input such as
	// "remote:index" (colon but no slash).
	// NOTE(review): realPath is not read after this point (newURL was already
	// parsed above), so the stripped path currently has no effect on validation
	// or on the upstream request — confirm the intended ordering.
	if parts := strings.SplitN(authPath, "/", 2); len(parts) == 2 && strings.Contains(parts[0], ":") {
		ccsParts := strings.SplitN(parts[0], ":", 2)
		realPath = fmt.Sprintf("%s/%s", ccsParts[1], parts[1])
	}

	newReq := req.Clone(context.Background())
	newReq.URL = newURL
	newReq.Method = method
	isSuperAdmin, permission, err := h.ValidateProxyRequest(newReq, targetClusterID)
	if err != nil {
		log.Error(err)
		resBody["error"] = err.Error()
		h.WriteJSON(w, resBody, http.StatusForbidden)
		return
	}
	if permission == "" && api.IsAuthEnable() && !isSuperAdmin {
		resBody["error"] = "unknown request path"
		h.WriteJSON(w, resBody, http.StatusForbidden)
		return
	}
	//if permission != "" {
	//	if permission == "cat.indices" || permission == "cat.shards" {
	//		reqUrl.Path
	//	}
	//}

	var (
		freq = httpPool.AcquireRequest()
		fres = httpPool.AcquireResponse()
	)
	defer func() {
		httpPool.ReleaseRequest(freq)
		httpPool.ReleaseResponse(fres)
	}()

	metadata := elastic.GetMetadata(targetClusterID)
	if metadata == nil {
		resBody["error"] = fmt.Sprintf("cluster [%s] metadata not found", targetClusterID)
		log.Error(resBody["error"])
		h.WriteJSON(w, resBody, http.StatusNotFound)
		return
	}
	if metadata.Config.BasicAuth != nil {
		freq.SetBasicAuth(metadata.Config.BasicAuth.Username, metadata.Config.BasicAuth.Password.Get())
	}

	endpoint := util.JoinPath(metadata.GetActivePreferredSeedEndpoint(), path)
	freq.SetRequestURI(endpoint)
	method = strings.ToUpper(method)
	freq.Header.SetMethod(method)
	freq.Header.SetUserAgent(req.Header.Get("user-agent"))
	freq.Header.SetReferer(endpoint)
	rurl, _ := url.Parse(endpoint)
	if rurl != nil {
		freq.Header.SetHost(rurl.Host)
		freq.Header.SetRequestURI(rurl.RequestURI())
	}
	clonedURI := freq.CloneURI()
	defer fasthttp.ReleaseURI(clonedURI)
	clonedURI.SetScheme(metadata.GetSchema())
	freq.SetURI(clonedURI)

	if permission == "cluster.search" {
		indices, hasAll := h.GetAllowedIndices(req, targetClusterID)
		if !hasAll && len(indices) == 0 {
			// No index is accessible: answer with an empty search response
			// instead of contacting the cluster.
			h.WriteJSON(w, elastic.SearchResponse{}, http.StatusOK)
			return
		}
		if hasAll {
			freq.SetBodyStream(req.Body, int(req.ContentLength))
		} else {
			body, err := io.ReadAll(req.Body)
			if err != nil {
				log.Error(err)
				h.WriteError(w, err.Error(), http.StatusInternalServerError)
				return
			}
			if len(body) == 0 {
				body = []byte("{}")
			}
			// A missing "query" key leaves v empty; malformed JSON is caught by
			// the jsonparser.Set call below.
			v, _, _, _ := jsonparser.Get(body, "query")
			newQ := bytes.NewBuffer([]byte(`{"bool": {"must": [{"terms": {"_index":`))
			indicesBytes := util.MustToJSONBytes(indices)
			newQ.Write(indicesBytes)
			newQ.Write([]byte("}}"))
			if len(v) > 0 {
				newQ.Write([]byte(","))
				newQ.Write(v)
			}
			newQ.Write([]byte(`]}}`))
			// Do not ignore this error: on failure Set returns a nil body, and
			// forwarding it would send the search without the index filter.
			body, err = jsonparser.Set(body, newQ.Bytes(), "query")
			if err != nil {
				log.Error(err)
				h.WriteError(w, err.Error(), http.StatusBadRequest)
				return
			}
			freq.SetBody(body)
		}
	} else {
		freq.SetBodyStream(req.Body, int(req.ContentLength))
	}
	defer req.Body.Close()

	err = api.GetFastHttpClient("elasticsearch_proxy").Do(freq, fres)
	if err != nil {
		resBody["error"] = err.Error()
		h.WriteJSON(w, resBody, http.StatusInternalServerError)
		return
	}

	okBody := struct {
		RequestHeader  string `json:"request_header"`
		ResponseHeader string `json:"response_header"`
		ResponseBody   string `json:"response_body"`
	}{
		RequestHeader:  freq.Header.String(),
		ResponseHeader: fres.Header.String(),
		ResponseBody:   string(fres.GetRawBody()),
	}

	// The upstream content type and status code are mirrored to the caller.
	w.Header().Set("Content-type", string(fres.Header.ContentType()))
	w.WriteHeader(fres.StatusCode())
	if err := json.NewEncoder(w).Encode(okBody); err != nil {
		// The status line is already written, so the failure can only be logged.
		log.Error(err)
	}
}
// rewriteTableNamesOfSqlRequest reads the JSON body of req, takes the SQL
// statement stored under its "query" key and adjusts the quoting of table
// names for the given distribution:
//
//   - elastic.Elasticsearch: names containing '-' or '.' that do not already
//     start with a double quote are wrapped in double quotes.
//   - elastic.Opensearch, elastic.Easysearch: names that start or end with a
//     double quote have the surrounding quotes removed.
//
// req.Body is consumed and replaced with a re-readable copy. When a rewrite
// happened, req.Body and req.ContentLength reflect the rewritten body.
//
// It returns the table names, with surrounding double quotes trimmed, joined
// by commas. An error is returned when the body cannot be read, has no string
// "query" value, the table names cannot be extracted, or the rewritten query
// cannot be written back into the body.
//
// NOTE: the replacement is textual over the whole statement, so a table name
// that also occurs elsewhere in the query (or inside a longer name) is
// replaced there as well.
func rewriteTableNamesOfSqlRequest(req *http.Request, distribution string) (string, error) {
	var buf bytes.Buffer
	if _, err := buf.ReadFrom(req.Body); err != nil {
		return "", err
	}
	if err := req.Body.Close(); err != nil {
		return "", err
	}
	rawBody := buf.Bytes()
	// Put the bytes back so later readers still see the original body.
	req.Body = io.NopCloser(bytes.NewReader(rawBody))

	sqlQuery, err := jsonparser.GetString(rawBody, "query")
	if err != nil {
		return "", fmt.Errorf("parse query from request body error: %w", err)
	}
	q := util.NewSQLQueryString(sqlQuery)
	tableNames, err := q.TableNames()
	if err != nil {
		return "", err
	}

	rewriteBody := false
	// Each distinct name is replaced once; replacing a repeated name a second
	// time would wrap the already quoted occurrences in quotes again.
	seen := make(map[string]struct{}, len(tableNames))
	for _, tname := range tableNames {
		if _, dup := seen[tname]; dup {
			continue
		}
		seen[tname] = struct{}{}

		switch distribution {
		case elastic.Elasticsearch:
			if strings.ContainsAny(tname, "-.") && !strings.HasPrefix(tname, "\"") {
				// Add plain quotes; JSON escaping happens when the query is
				// marshalled below.
				sqlQuery = strings.ReplaceAll(sqlQuery, tname, `"`+tname+`"`)
				rewriteBody = true
			}
		case elastic.Opensearch, elastic.Easysearch:
			if strings.HasPrefix(tname, "\"") || strings.HasSuffix(tname, "\"") {
				sqlQuery = strings.ReplaceAll(sqlQuery, tname, strings.Trim(tname, "\""))
				rewriteBody = true
			}
		}
	}

	if rewriteBody {
		// GetString returned the unescaped statement, so it must be encoded as
		// a JSON string again; wrapping it in quotes by hand yields invalid
		// JSON whenever the query contains quotes, backslashes or newlines.
		encodedQuery, err := json.Marshal(sqlQuery)
		if err != nil {
			return "", fmt.Errorf("encode rewritten query error: %w", err)
		}
		reqBody, err := jsonparser.Set(rawBody, encodedQuery, "query")
		if err != nil {
			return "", fmt.Errorf("rewrite query of request body error: %w", err)
		}
		req.Body = io.NopCloser(bytes.NewReader(reqBody))
		req.ContentLength = int64(len(reqBody))
	}

	unescapedTableNames := make([]string, 0, len(tableNames))
	for _, tname := range tableNames {
		unescapedTableNames = append(unescapedTableNames, strings.Trim(tname, "\""))
	}
	return strings.Join(unescapedTableNames, ","), nil
}