From ac683e2329098c937fa794c64d608ca95347bce7 Mon Sep 17 00:00:00 2001 From: qiwang <1364512070@qq.com> Date: Tue, 25 Jun 2024 10:09:01 +0800 Subject: [PATCH 01/21] fix: update taskai 0625 Former-commit-id: 1e45411185ec160a4d2bd2bb0dac1d8f11982990 --- api/Dockerfile | 4 +- api/etc/pcm.yaml | 4 +- .../logic/inference/imageinferencelogic.go | 94 +++---------------- pkg/utils/aksk_sign.go | 84 +++++++++++++++++ 4 files changed, 102 insertions(+), 84 deletions(-) create mode 100644 pkg/utils/aksk_sign.go diff --git a/api/Dockerfile b/api/Dockerfile index 3c78b3fd..a27e861c 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1.22-alpine3.18 AS builder +FROM golang:1.22.4-alpine3.20 AS builder WORKDIR /app @@ -9,7 +9,7 @@ RUN go env -w GO111MODULE=on \ && go env -w CGO_ENABLED=0 RUN go build -o pcm-coordinator-api /app/api/pcm.go -FROM alpine:3.18 +FROM alpine:3.20 WORKDIR /app diff --git a/api/etc/pcm.yaml b/api/etc/pcm.yaml index ef1e3250..696cd563 100644 --- a/api/etc/pcm.yaml +++ b/api/etc/pcm.yaml @@ -6,8 +6,8 @@ MaxBytes: 524288000 Timeout: 50000 DB: - DataSource: root:uJpLd6u-J?HC1@(10.206.0.12:3306)/pcm?parseTime=true&loc=Local - # DataSource: root:uJpLd6u-J?HC1@(47.92.88.143:3306)/pcm?parseTime=true&loc=Local + #DataSource: root:uJpLd6u-J?HC1@(10.206.0.12:3306)/pcm?parseTime=true&loc=Local + DataSource: root:uJpLd6u-J?HC1@(47.92.88.143:3306)/pcm?parseTime=true&loc=Local Redis: Host: 10.206.0.12:6379 Pass: redisPW123 diff --git a/api/internal/logic/inference/imageinferencelogic.go b/api/internal/logic/inference/imageinferencelogic.go index f52b580d..19d68509 100644 --- a/api/internal/logic/inference/imageinferencelogic.go +++ b/api/internal/logic/inference/imageinferencelogic.go @@ -1,12 +1,8 @@ package inference import ( - "bytes" "context" - "crypto/tls" "errors" - "fmt" - "github.com/JCCE-nudt/apigw-go-sdk/core" "github.com/go-resty/resty/v2" "github.com/zeromicro/go-zero/core/logx" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/schedulers/option" @@ -16,7 +12,7 @@ import ( "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types" "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants" "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models" - "io" + "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/utils" "k8s.io/apimachinery/pkg/util/json" "log" "math/rand" @@ -459,7 +455,19 @@ func getInferResult(url string, file multipart.File, fileName string, clusterNam func getInferResultModelarts(url string, file multipart.File, fileName string) (string, error) { var res Res - body, err := SendRequest("POST", url, file, fileName) + /* req := GetRestyRequest(20) + _, err := req. + SetFileReader("file", fileName, file). + SetHeaders(map[string]string{ + "ak": "UNEHPHO4Z7YSNPKRXFE4", + "sk": "JWXCE9qcYbc7RjpSRIWt4WgG3ZKF6Q4lPzkJReX9", + }). + SetResult(&res). + Post(url) + if err != nil { + return "", err + }*/ + body, err := utils.SendRequest("POST", url, file, fileName) if err != nil { return "", err } @@ -467,83 +475,9 @@ func getInferResultModelarts(url string, file multipart.File, fileName string) ( if errjson != nil { log.Fatalf("Error parsing JSON: %s", errjson) } - return res.Result, nil } -// SignClient AK/SK签名认证 -func SignClient(r *http.Request, writer *multipart.Writer) (*http.Client, error) { - r.Header.Add("content-type", "application/json;charset=UTF-8") - r.Header.Add("X-Project-Id", "d18190e28e3f45a281ef0b0696ec9d52") - r.Header.Add("x-stage", "RELEASE") - r.Header.Add("x-sdk-content-sha256", "UNSIGNED-PAYLOAD") - r.Header.Set("Content-Type", writer.FormDataContentType()) - s := core.Signer{ - Key: "UNEHPHO4Z7YSNPKRXFE4", - Secret: "JWXCE9qcYbc7RjpSRIWt4WgG3ZKF6Q4lPzkJReX9", - } - err := s.Sign(r) - if err != nil { - return nil, err - } - - //设置client信任所有证书 - tr := &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - } - client := &http.Client{ - Transport: tr, - } - return client, nil -} - -func SendRequest(method, url string, file multipart.File, fileName string) (string, error) { - /*body := &bytes.Buffer{} - writer := multipart.NewWriter(body)*/ - // 创建一个新的缓冲区以写入multipart表单 - var body bytes.Buffer - // 创建一个新的multipart writer - writer := multipart.NewWriter(&body) - // 创建一个用于写入文件的表单字段 - part, err := writer.CreateFormFile("file", fileName) // "file"是表单的字段名,第二个参数是文件名 - if err != nil { - fmt.Println("Error creating form file:", err) - } - // 将文件的内容拷贝到multipart writer中 - _, err = io.Copy(part, file) - if err != nil { - fmt.Println("Error copying file data:", err) - - } - err = writer.Close() - if err != nil { - fmt.Println("Error closing multipart writer:", err) - } - request, err := http.NewRequest(method, url, &body) - if err != nil { - fmt.Println("Error creating new request:", err) - //return nil, err - } - signedR, err := SignClient(request, writer) - if err != nil { - fmt.Println("Error signing request:", err) - //return nil, err - } - - res, err := signedR.Do(request) - if err != nil { - fmt.Println("Error sending request:", err) - return "", err - } - //defer res.Body.Close() - Resbody, err := io.ReadAll(res.Body) - if err != nil { - fmt.Println("Error reading response body:", err) - //return nil, err - } - return string(Resbody), nil -} - func GetRestyRequest(timeoutSeconds int64) *resty.Request { client := resty.New().SetTimeout(time.Duration(timeoutSeconds) * time.Second) request := client.R() diff --git a/pkg/utils/aksk_sign.go b/pkg/utils/aksk_sign.go new file mode 100644 index 00000000..a511129f --- /dev/null +++ b/pkg/utils/aksk_sign.go @@ -0,0 +1,84 @@ +package utils + +import ( + "bytes" + "crypto/tls" + "fmt" + "github.com/JCCE-nudt/apigw-go-sdk/core" + "io" + "mime/multipart" + "net/http" +) + +// SignClient AK/SK签名认证 +func SignClient(r *http.Request, writer *multipart.Writer) (*http.Client, error) { + r.Header.Add("content-type", "application/json;charset=UTF-8") + r.Header.Add("X-Project-Id", "d18190e28e3f45a281ef0b0696ec9d52") + r.Header.Add("x-stage", "RELEASE") + r.Header.Add("x-sdk-content-sha256", "UNSIGNED-PAYLOAD") + //r.Header.Set("Content-Type", writer.FormDataContentType()) + s := core.Signer{ + Key: "UNEHPHO4Z7YSNPKRXFE4", + Secret: "JWXCE9qcYbc7RjpSRIWt4WgG3ZKF6Q4lPzkJReX9", + } + err := s.Sign(r) + if err != nil { + return nil, err + } + + //设置client信任所有证书 + tr := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + client := &http.Client{ + Transport: tr, + } + return client, nil +} + +func SendRequest(method, url string, file multipart.File, fileName string) (string, error) { + /*body := &bytes.Buffer{} + writer := multipart.NewWriter(body)*/ + // 创建一个新的缓冲区以写入multipart表单 + var body bytes.Buffer + // 创建一个新的multipart writer + writer := multipart.NewWriter(&body) + // 创建一个用于写入文件的表单字段 + part, err := writer.CreateFormFile("file", fileName) // "file"是表单的字段名,第二个参数是文件名 + if err != nil { + fmt.Println("Error creating form file:", err) + } + // 将文件的内容拷贝到multipart writer中 + _, err = io.Copy(part, file) + if err != nil { + fmt.Println("Error copying file data:", err) + + } + err = writer.Close() + if err != nil { + fmt.Println("Error closing multipart writer:", err) + } + request, err := http.NewRequest(method, url, &body) + if err != nil { + fmt.Println("Error creating new request:", err) + //return nil, err + } + signedR, err := SignClient(request, writer) + if err != nil { + fmt.Println("Error signing request:", err) + //return nil, err + } + + res, err := signedR.Do(request) + if err != nil { + fmt.Println("Error sending request:", err) + return "", err + } + //defer res.Body.Close() + Resbody, err := io.ReadAll(res.Body) + if err != nil { + fmt.Println("Error reading response body:", err) + //return nil, err + } + return string(Resbody), nil +} From 49c415adabf3190b43014750bd62a5507e3c451c Mon Sep 17 00:00:00 2001 From: qiwang <1364512070@qq.com> Date: Tue, 25 Jun 2024 10:11:47 +0800 Subject: [PATCH 02/21] fix: update taskai 0625 Former-commit-id: 712cc1b7bc7b6419c0f8d28ff7e1c4317bcc5252 --- api/etc/pcm.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/etc/pcm.yaml b/api/etc/pcm.yaml index 696cd563..ef1e3250 100644 --- a/api/etc/pcm.yaml +++ b/api/etc/pcm.yaml @@ -6,8 +6,8 @@ MaxBytes: 524288000 Timeout: 50000 DB: - #DataSource: root:uJpLd6u-J?HC1@(10.206.0.12:3306)/pcm?parseTime=true&loc=Local - DataSource: root:uJpLd6u-J?HC1@(47.92.88.143:3306)/pcm?parseTime=true&loc=Local + DataSource: root:uJpLd6u-J?HC1@(10.206.0.12:3306)/pcm?parseTime=true&loc=Local + # DataSource: root:uJpLd6u-J?HC1@(47.92.88.143:3306)/pcm?parseTime=true&loc=Local Redis: Host: 10.206.0.12:6379 Pass: redisPW123 From b880b5b2ebc214e7dfcd0cf050a96496d53fc41d Mon Sep 17 00:00:00 2001 From: qiwang <1364512070@qq.com> Date: Tue, 25 Jun 2024 10:36:50 +0800 Subject: [PATCH 03/21] fix: update taskai 0625 Former-commit-id: 0297a4c6a4fb5a2e166041c8f96ccdde457593a0 --- pkg/utils/aksk_sign.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/utils/aksk_sign.go b/pkg/utils/aksk_sign.go index a511129f..c88d3b85 100644 --- a/pkg/utils/aksk_sign.go +++ b/pkg/utils/aksk_sign.go @@ -16,7 +16,7 @@ func SignClient(r *http.Request, writer *multipart.Writer) (*http.Client, error) r.Header.Add("X-Project-Id", "d18190e28e3f45a281ef0b0696ec9d52") r.Header.Add("x-stage", "RELEASE") r.Header.Add("x-sdk-content-sha256", "UNSIGNED-PAYLOAD") - //r.Header.Set("Content-Type", writer.FormDataContentType()) + r.Header.Set("Content-Type", writer.FormDataContentType()) s := core.Signer{ Key: "UNEHPHO4Z7YSNPKRXFE4", Secret: "JWXCE9qcYbc7RjpSRIWt4WgG3ZKF6Q4lPzkJReX9", From 2cd7f793b4d869cf202268877762bd7aafb7c1eb Mon Sep 17 00:00:00 2001 From: tzwang Date: Tue, 25 Jun 2024 15:53:12 +0800 Subject: [PATCH 04/21] fix octopus imageUrls bugs Former-commit-id: 80d170216054490c90601bb94085620e1edd70ec --- api/internal/storeLink/octopus.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/api/internal/storeLink/octopus.go b/api/internal/storeLink/octopus.go index 6219f86a..a902cf93 100644 --- a/api/internal/storeLink/octopus.go +++ b/api/internal/storeLink/octopus.go @@ -884,7 +884,7 @@ func (o *OctopusLink) GetImageInferUrl(ctx context.Context, option *option.Infer var imageUrls []*collector.ImageInferUrl for _, notebook := range list.Payload.GetNotebooks() { - if strings.Contains(notebook.AlgorithmName, option.ModelName) { + if strings.Contains(notebook.AlgorithmName, option.ModelName) && notebook.Status == "running" { url := strings.Replace(notebook.Tasks[0].Url, FORWARD_SLASH, "", -1) names := strings.Split(notebook.AlgorithmName, UNDERSCORE) imageUrl := &collector.ImageInferUrl{ @@ -896,5 +896,9 @@ func (o *OctopusLink) GetImageInferUrl(ctx context.Context, option *option.Infer continue } } + + if len(imageUrls) == 0 { + return nil, errors.New("no infer url available") + } return imageUrls, nil } From 1ee960f52ada14d4357d7c2017ff8c3d475c6bbf Mon Sep 17 00:00:00 2001 From: tzwang Date: Tue, 25 Jun 2024 17:06:40 +0800 Subject: [PATCH 05/21] added textinference api Former-commit-id: d32678e386f68034fa9f932975c2d09c6c825e2c --- api/desc/inference/inference.api | 12 +++++++- api/desc/pcm.api | 3 ++ .../inference/texttotextinferencehandler.go | 28 +++++++++++++++++ api/internal/handler/routes.go | 5 ++++ .../inference/texttotextinferencelogic.go | 30 +++++++++++++++++++ api/internal/types/types.go | 12 ++++++++ 6 files changed, 89 insertions(+), 1 deletion(-) create mode 100644 api/internal/handler/inference/texttotextinferencehandler.go create mode 100644 api/internal/logic/inference/texttotextinferencelogic.go diff --git a/api/desc/inference/inference.api b/api/desc/inference/inference.api index ba9b1946..1872d2c6 100644 --- a/api/desc/inference/inference.api +++ b/api/desc/inference/inference.api @@ -63,7 +63,17 @@ type ( clusterName string `json:"clusterName"` } + /******************TextToText inference*************************/ + TextToTextInferenceReq{ + TaskName string `form:"taskName"` + TaskDesc string `form:"taskDesc"` + ModelName string `form:"modelName"` + ModelType string `form:"modelType"` + AdapterId string `form:"adapterId"` + ClusterId string `form:"clusterIds"` + } + TextToTextInferenceResp{ - + } ) diff --git a/api/desc/pcm.api b/api/desc/pcm.api index b912fda9..0a7b51bf 100644 --- a/api/desc/pcm.api +++ b/api/desc/pcm.api @@ -907,6 +907,9 @@ service pcm { group: inference ) service pcm { + @handler TextToTextInferenceHandler + post /inference/text (TextToTextInferenceReq) returns (TextToTextInferenceResp) + @handler ImageInferenceHandler post /inference/images (ImageInferenceReq) returns (ImageInferenceResp) diff --git a/api/internal/handler/inference/texttotextinferencehandler.go b/api/internal/handler/inference/texttotextinferencehandler.go new file mode 100644 index 00000000..55df289c --- /dev/null +++ b/api/internal/handler/inference/texttotextinferencehandler.go @@ -0,0 +1,28 @@ +package inference + +import ( + "net/http" + + "github.com/zeromicro/go-zero/rest/httpx" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/logic/inference" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types" +) + +func TextToTextInferenceHandler(svcCtx *svc.ServiceContext) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + var req types.TextToTextInferenceReq + if err := httpx.Parse(r, &req); err != nil { + httpx.ErrorCtx(r.Context(), w, err) + return + } + + l := inference.NewTextToTextInferenceLogic(r.Context(), svcCtx) + resp, err := l.TextToTextInference(&req) + if err != nil { + httpx.ErrorCtx(r.Context(), w, err) + } else { + httpx.OkJsonCtx(r.Context(), w, resp) + } + } +} diff --git a/api/internal/handler/routes.go b/api/internal/handler/routes.go index e04c7750..2daff30f 100644 --- a/api/internal/handler/routes.go +++ b/api/internal/handler/routes.go @@ -1138,6 +1138,11 @@ func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) { server.AddRoutes( []rest.Route{ + { + Method: http.MethodPost, + Path: "/inference/text", + Handler: inference.TextToTextInferenceHandler(serverCtx), + }, { Method: http.MethodPost, Path: "/inference/images", diff --git a/api/internal/logic/inference/texttotextinferencelogic.go b/api/internal/logic/inference/texttotextinferencelogic.go new file mode 100644 index 00000000..faf61301 --- /dev/null +++ b/api/internal/logic/inference/texttotextinferencelogic.go @@ -0,0 +1,30 @@ +package inference + +import ( + "context" + + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types" + + "github.com/zeromicro/go-zero/core/logx" +) + +type TextToTextInferenceLogic struct { + logx.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +func NewTextToTextInferenceLogic(ctx context.Context, svcCtx *svc.ServiceContext) *TextToTextInferenceLogic { + return &TextToTextInferenceLogic{ + Logger: logx.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *TextToTextInferenceLogic) TextToTextInference(req *types.TextToTextInferenceReq) (resp *types.TextToTextInferenceResp, err error) { + // todo: add your logic here and delete this line + + return +} diff --git a/api/internal/types/types.go b/api/internal/types/types.go index 86f165bd..b2a942da 100644 --- a/api/internal/types/types.go +++ b/api/internal/types/types.go @@ -5941,3 +5941,15 @@ type InferenceResult struct { Card string `json:"card"` ClusterName string `json:"clusterName"` } + +type TextToTextInferenceReq struct { + TaskName string `form:"taskName"` + TaskDesc string `form:"taskDesc"` + ModelName string `form:"modelName"` + ModelType string `form:"modelType"` + AdapterId string `form:"adapterId"` + ClusterId string `form:"clusterIds"` +} + +type TextToTextInferenceResp struct { +} From c9084ff0936ec4d08c59e6da7102efa13e4c4294 Mon Sep 17 00:00:00 2001 From: tzwang Date: Tue, 25 Jun 2024 17:12:37 +0800 Subject: [PATCH 06/21] added textinference api Former-commit-id: ac9509e26c6a33b5caf86b8a2e6ff19b1e8ae0f7 --- api/desc/inference/inference.api | 2 +- api/internal/types/types.go | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/api/desc/inference/inference.api b/api/desc/inference/inference.api index 1872d2c6..f2e88943 100644 --- a/api/desc/inference/inference.api +++ b/api/desc/inference/inference.api @@ -70,7 +70,7 @@ type ( ModelName string `form:"modelName"` ModelType string `form:"modelType"` AdapterId string `form:"adapterId"` - ClusterId string `form:"clusterIds"` + AiClusterIds []string `form:"aiClusterIds"` } TextToTextInferenceResp{ diff --git a/api/internal/types/types.go b/api/internal/types/types.go index b2a942da..7ab16536 100644 --- a/api/internal/types/types.go +++ b/api/internal/types/types.go @@ -5943,12 +5943,12 @@ type InferenceResult struct { } type TextToTextInferenceReq struct { - TaskName string `form:"taskName"` - TaskDesc string `form:"taskDesc"` - ModelName string `form:"modelName"` - ModelType string `form:"modelType"` - AdapterId string `form:"adapterId"` - ClusterId string `form:"clusterIds"` + TaskName string `form:"taskName"` + TaskDesc string `form:"taskDesc"` + ModelName string `form:"modelName"` + ModelType string `form:"modelType"` + AdapterId string `form:"adapterId"` + AiClusterIds []string `form:"aiClusterIds"` } type TextToTextInferenceResp struct { From e6b9d3d23b521a10623426436f0db40f2ee4457a Mon Sep 17 00:00:00 2001 From: tzwang Date: Tue, 25 Jun 2024 18:19:18 +0800 Subject: [PATCH 07/21] added textinfer api Former-commit-id: bfdce9025139d71bc3178b039756d579be1b450c --- .../inference/texttotextinferencehandler.go | 13 +- .../logic/inference/imageinferencelogic.go | 14 +- .../inference/texttotextinferencelogic.go | 120 +++++++++++++++++- .../scheduler/service/collector/collector.go | 4 +- api/internal/storeLink/modelarts.go | 6 +- api/internal/storeLink/octopus.go | 6 +- api/internal/storeLink/shuguangai.go | 6 +- api/internal/storeLink/storeLink.go | 1 + pkg/models/taskaimodel_gen.go | 1 + 9 files changed, 140 insertions(+), 31 deletions(-) diff --git a/api/internal/handler/inference/texttotextinferencehandler.go b/api/internal/handler/inference/texttotextinferencehandler.go index 55df289c..d291d60e 100644 --- a/api/internal/handler/inference/texttotextinferencehandler.go +++ b/api/internal/handler/inference/texttotextinferencehandler.go @@ -1,28 +1,25 @@ package inference import ( - "net/http" - "github.com/zeromicro/go-zero/rest/httpx" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/logic/inference" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types" + "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/repository/result" + "net/http" ) func TextToTextInferenceHandler(svcCtx *svc.ServiceContext) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var req types.TextToTextInferenceReq if err := httpx.Parse(r, &req); err != nil { - httpx.ErrorCtx(r.Context(), w, err) + result.ParamErrorResult(r, w, err) return } l := inference.NewTextToTextInferenceLogic(r.Context(), svcCtx) resp, err := l.TextToTextInference(&req) - if err != nil { - httpx.ErrorCtx(r.Context(), w, err) - } else { - httpx.OkJsonCtx(r.Context(), w, resp) - } + result.HttpResult(r, w, resp, err) + } } diff --git a/api/internal/logic/inference/imageinferencelogic.go b/api/internal/logic/inference/imageinferencelogic.go index ec40d60d..48aaf97a 100644 --- a/api/internal/logic/inference/imageinferencelogic.go +++ b/api/internal/logic/inference/imageinferencelogic.go @@ -133,14 +133,14 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s var wg sync.WaitGroup var cluster_ch = make(chan struct { - urls []*collector.ImageInferUrl + urls []*collector.InferUrl clusterId string clusterName string imageNum int32 }, len(clusters)) var cs []struct { - urls []*collector.ImageInferUrl + urls []*collector.InferUrl clusterId string clusterName string imageNum int32 @@ -182,7 +182,7 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s wg.Add(1) c := cluster go func() { - imageUrls, err := collectorMap[c.ClusterId].GetImageInferUrl(ctx, opt) + imageUrls, err := collectorMap[c.ClusterId].GetInferUrl(ctx, opt) if err != nil { wg.Done() return @@ -190,7 +190,7 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s clusterName, _ := svcCtx.Scheduler.AiStorages.GetClusterNameById(c.ClusterId) s := struct { - urls []*collector.ImageInferUrl + urls []*collector.InferUrl clusterId string clusterName string imageNum int32 @@ -373,7 +373,7 @@ func sendInferReq(images []struct { imageResult *types.ImageResult file multipart.File }, cluster struct { - urls []*collector.ImageInferUrl + urls []*collector.InferUrl clusterId string clusterName string imageNum int32 @@ -384,7 +384,7 @@ func sendInferReq(images []struct { imageResult *types.ImageResult file multipart.File }, c struct { - urls []*collector.ImageInferUrl + urls []*collector.InferUrl clusterId string clusterName string imageNum int32 @@ -494,7 +494,7 @@ type Res struct { } func contains(cs []struct { - urls []*collector.ImageInferUrl + urls []*collector.InferUrl clusterId string clusterName string imageNum int32 diff --git a/api/internal/logic/inference/texttotextinferencelogic.go b/api/internal/logic/inference/texttotextinferencelogic.go index faf61301..2974d834 100644 --- a/api/internal/logic/inference/texttotextinferencelogic.go +++ b/api/internal/logic/inference/texttotextinferencelogic.go @@ -2,11 +2,18 @@ package inference import ( "context" - + "errors" + "github.com/zeromicro/go-zero/core/logx" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/schedulers/option" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/collector" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/storeLink" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types" - - "github.com/zeromicro/go-zero/core/logx" + "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants" + "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models" + "strconv" + "sync" + "time" ) type TextToTextInferenceLogic struct { @@ -24,7 +31,110 @@ func NewTextToTextInferenceLogic(ctx context.Context, svcCtx *svc.ServiceContext } func (l *TextToTextInferenceLogic) TextToTextInference(req *types.TextToTextInferenceReq) (resp *types.TextToTextInferenceResp, err error) { - // todo: add your logic here and delete this line + resp = &types.TextToTextInferenceResp{} + opt := &option.InferOption{ + TaskName: req.TaskName, + TaskDesc: req.TaskDesc, + AdapterId: req.AdapterId, + AiClusterIds: req.AiClusterIds, + ModelName: req.ModelName, + ModelType: req.ModelType, + } - return + _, ok := l.svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId] + if !ok { + return nil, errors.New("AdapterId does not exist") + } + + //save task + var synergystatus int64 + var strategyCode int64 + adapterName, err := l.svcCtx.Scheduler.AiStorages.GetAdapterNameById(opt.AdapterId) + if err != nil { + return nil, err + } + + id, err := l.svcCtx.Scheduler.AiStorages.SaveTask(opt.TaskName, strategyCode, synergystatus, "12") + if err != nil { + return nil, err + } + + var wg sync.WaitGroup + var cluster_ch = make(chan struct { + urls []*collector.InferUrl + clusterId string + clusterName string + }, len(opt.AiClusterIds)) + + var cs []struct { + urls []*collector.InferUrl + clusterId string + clusterName string + } + collectorMap := l.svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId] + + //save taskai + for _, clusterId := range opt.AiClusterIds { + wg.Add(1) + go func(cId string) { + urls, err := collectorMap[cId].GetInferUrl(l.ctx, opt) + if err != nil { + wg.Done() + return + } + clusterName, _ := l.svcCtx.Scheduler.AiStorages.GetClusterNameById(cId) + + s := struct { + urls []*collector.InferUrl + clusterId string + clusterName string + }{ + urls: urls, + clusterId: cId, + clusterName: clusterName, + } + + cluster_ch <- s + wg.Done() + return + }(clusterId) + } + wg.Wait() + close(cluster_ch) + + for s := range cluster_ch { + cs = append(cs, s) + } + + for _, c := range cs { + clusterName, _ := l.svcCtx.Scheduler.AiStorages.GetClusterNameById(c.clusterId) + err := l.svcCtx.Scheduler.AiStorages.SaveAiTask(id, opt, adapterName, c.clusterId, clusterName, "", constants.Saved, "") + if err != nil { + return nil, err + } + } + + var aiTaskList []*models.TaskAi + tx := l.svcCtx.DbEngin.Raw("select * from task_ai where `task_id` = ? ", id).Scan(&aiTaskList) + if tx.Error != nil { + return nil, tx.Error + + } + + for i, t := range aiTaskList { + if strconv.Itoa(int(t.ClusterId)) == cs[i].clusterId { + t.Status = constants.Completed + t.EndTime = time.Now().Format(time.RFC3339) + url := cs[i].urls[0].Url + storeLink.FORWARD_SLASH + "chat" + t.InferUrl = url + err := l.svcCtx.Scheduler.AiStorages.UpdateAiTask(t) + if err != nil { + logx.Errorf(tx.Error.Error()) + } + } + } + + l.svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "completed", "任务完成") + + return resp, nil } diff --git a/api/internal/scheduler/service/collector/collector.go b/api/internal/scheduler/service/collector/collector.go index 1fa777ed..4c978628 100644 --- a/api/internal/scheduler/service/collector/collector.go +++ b/api/internal/scheduler/service/collector/collector.go @@ -15,10 +15,10 @@ type AiCollector interface { UploadAlgorithmCode(ctx context.Context, resourceType string, card string, taskType string, dataset string, algorithm string, code string) error GetComputeCards(ctx context.Context) ([]string, error) GetUserBalance(ctx context.Context) (float64, error) - GetImageInferUrl(ctx context.Context, option *option.InferOption) ([]*ImageInferUrl, error) + GetInferUrl(ctx context.Context, option *option.InferOption) ([]*InferUrl, error) } -type ImageInferUrl struct { +type InferUrl struct { Url string Card string } diff --git a/api/internal/storeLink/modelarts.go b/api/internal/storeLink/modelarts.go index a2532730..331becb1 100644 --- a/api/internal/storeLink/modelarts.go +++ b/api/internal/storeLink/modelarts.go @@ -378,8 +378,8 @@ func (m *ModelArtsLink) generateAlgorithmId(ctx context.Context, option *option. return errors.New("failed to get AlgorithmId") } -func (m *ModelArtsLink) GetImageInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.ImageInferUrl, error) { - var imageUrls []*collector.ImageInferUrl +func (m *ModelArtsLink) GetInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.InferUrl, error) { + var imageUrls []*collector.InferUrl urlReq := &modelartsclient.ImageReasoningUrlReq{ ModelName: option.ModelName, Type: option.ModelType, @@ -389,7 +389,7 @@ func (m *ModelArtsLink) GetImageInferUrl(ctx context.Context, option *option.Inf if err != nil { return nil, err } - imageUrl := &collector.ImageInferUrl{ + imageUrl := &collector.InferUrl{ Url: urlResp.Url, Card: "npu", } diff --git a/api/internal/storeLink/octopus.go b/api/internal/storeLink/octopus.go index a902cf93..9c605345 100644 --- a/api/internal/storeLink/octopus.go +++ b/api/internal/storeLink/octopus.go @@ -871,7 +871,7 @@ func setResourceIdByCard(option *option.AiOption, specs *octopus.GetResourceSpec return errors.New("set ResourceId error") } -func (o *OctopusLink) GetImageInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.ImageInferUrl, error) { +func (o *OctopusLink) GetInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.InferUrl, error) { req := &octopus.GetNotebookListReq{ Platform: o.platform, PageIndex: o.pageIndex, @@ -882,12 +882,12 @@ func (o *OctopusLink) GetImageInferUrl(ctx context.Context, option *option.Infer return nil, err } - var imageUrls []*collector.ImageInferUrl + var imageUrls []*collector.InferUrl for _, notebook := range list.Payload.GetNotebooks() { if strings.Contains(notebook.AlgorithmName, option.ModelName) && notebook.Status == "running" { url := strings.Replace(notebook.Tasks[0].Url, FORWARD_SLASH, "", -1) names := strings.Split(notebook.AlgorithmName, UNDERSCORE) - imageUrl := &collector.ImageInferUrl{ + imageUrl := &collector.InferUrl{ Url: DOMAIN + url + FORWARD_SLASH + "image", Card: names[2], } diff --git a/api/internal/storeLink/shuguangai.go b/api/internal/storeLink/shuguangai.go index c0ca6d93..2a634a4c 100644 --- a/api/internal/storeLink/shuguangai.go +++ b/api/internal/storeLink/shuguangai.go @@ -730,8 +730,8 @@ func (s *ShuguangAi) generateParams(option *option.AiOption) error { return errors.New("failed to set params") } -func (s *ShuguangAi) GetImageInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.ImageInferUrl, error) { - var imageUrls []*collector.ImageInferUrl +func (s *ShuguangAi) GetInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.InferUrl, error) { + var imageUrls []*collector.InferUrl urlReq := &hpcAC.GetInferUrlReq{ ModelName: option.ModelName, @@ -743,7 +743,7 @@ func (s *ShuguangAi) GetImageInferUrl(ctx context.Context, option *option.InferO if err != nil { return nil, err } - imageUrl := &collector.ImageInferUrl{ + imageUrl := &collector.InferUrl{ Url: urlResp.Url, Card: "dcu", } diff --git a/api/internal/storeLink/storeLink.go b/api/internal/storeLink/storeLink.go index d5ead2e6..2cd54f59 100644 --- a/api/internal/storeLink/storeLink.go +++ b/api/internal/storeLink/storeLink.go @@ -78,6 +78,7 @@ var ( } ModelTypeMap = map[string][]string{ "image_recognition": {"imagenet_resnet50"}, + "text_to_text": {"chatGLM-6B"}, } AITYPE = map[string]string{ "1": OCTOPUS, diff --git a/pkg/models/taskaimodel_gen.go b/pkg/models/taskaimodel_gen.go index 6b0e07bc..24473016 100644 --- a/pkg/models/taskaimodel_gen.go +++ b/pkg/models/taskaimodel_gen.go @@ -54,6 +54,7 @@ type ( TaskType string `db:"task_type"` DeletedAt *time.Time `db:"deleted_at"` Card string `db:"card"` + InferUrl string `db:"infer_url"` } ) From 125c9a3ce3a47477b4565b390bf1f5a13a5b9a42 Mon Sep 17 00:00:00 2001 From: tzwang Date: Tue, 25 Jun 2024 18:42:42 +0800 Subject: [PATCH 08/21] updated textinfer api Former-commit-id: 462ab959adaae03bdfaa40adf53b361bf1978855 --- api/internal/storeLink/storeLink.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/internal/storeLink/storeLink.go b/api/internal/storeLink/storeLink.go index 2cd54f59..69cfc02d 100644 --- a/api/internal/storeLink/storeLink.go +++ b/api/internal/storeLink/storeLink.go @@ -78,7 +78,7 @@ var ( } ModelTypeMap = map[string][]string{ "image_recognition": {"imagenet_resnet50"}, - "text_to_text": {"chatGLM-6B"}, + "text_to_text": {"chatGLM_6B"}, } AITYPE = map[string]string{ "1": OCTOPUS, From fd0e82349020c892aad41b5bc42819b61f4c3181 Mon Sep 17 00:00:00 2001 From: tzwang Date: Tue, 25 Jun 2024 19:19:48 +0800 Subject: [PATCH 09/21] updated textinfer api Former-commit-id: d0c9203bfb6af1f141e9bf2468d35cc1ab642084 --- .../logic/inference/texttotextinferencelogic.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/api/internal/logic/inference/texttotextinferencelogic.go b/api/internal/logic/inference/texttotextinferencelogic.go index 2974d834..58b3edc1 100644 --- a/api/internal/logic/inference/texttotextinferencelogic.go +++ b/api/internal/logic/inference/texttotextinferencelogic.go @@ -106,6 +106,16 @@ func (l *TextToTextInferenceLogic) TextToTextInference(req *types.TextToTextInfe cs = append(cs, s) } + if len(cs) == 0 { + clusterId := opt.AiClusterIds[0] + clusterName, _ := l.svcCtx.Scheduler.AiStorages.GetClusterNameById(opt.AiClusterIds[0]) + err := l.svcCtx.Scheduler.AiStorages.SaveAiTask(id, opt, adapterName, clusterId, clusterName, "", constants.Failed, "") + if err != nil { + return nil, err + } + l.svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败") + } + for _, c := range cs { clusterName, _ := l.svcCtx.Scheduler.AiStorages.GetClusterNameById(c.clusterId) err := l.svcCtx.Scheduler.AiStorages.SaveAiTask(id, opt, adapterName, c.clusterId, clusterName, "", constants.Saved, "") From 45b1171fb0d58b7856b554d2545e9b34dd3c1fc5 Mon Sep 17 00:00:00 2001 From: qiwang <1364512070@qq.com> Date: Tue, 25 Jun 2024 20:01:38 +0800 Subject: [PATCH 10/21] fix: update taskai 0625 Former-commit-id: 8f0176a0e31aab784995f4f4211b208b58e5c7ff --- api/internal/logic/inference/imageinferencelogic.go | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/api/internal/logic/inference/imageinferencelogic.go b/api/internal/logic/inference/imageinferencelogic.go index ec40d60d..b803e442 100644 --- a/api/internal/logic/inference/imageinferencelogic.go +++ b/api/internal/logic/inference/imageinferencelogic.go @@ -460,18 +460,6 @@ func getInferResult(url string, file multipart.File, fileName string, clusterNam func getInferResultModelarts(url string, file multipart.File, fileName string) (string, error) { var res Res - /* req := GetRestyRequest(20) - _, err := req. - SetFileReader("file", fileName, file). - SetHeaders(map[string]string{ - "ak": "UNEHPHO4Z7YSNPKRXFE4", - "sk": "JWXCE9qcYbc7RjpSRIWt4WgG3ZKF6Q4lPzkJReX9", - }). - SetResult(&res). - Post(url) - if err != nil { - return "", err - }*/ body, err := utils.SendRequest("POST", url, file, fileName) if err != nil { return "", err From 8428cc0ebb4f353e11f76d38ad98454a6624d8c8 Mon Sep 17 00:00:00 2001 From: jagger Date: Tue, 25 Jun 2024 20:25:47 +0800 Subject: [PATCH 11/21] fix bug Signed-off-by: jagger Former-commit-id: 04902084af3b14cf403f9cfad4eeb38e05ccee3d --- api/desc/ai/pcm-ai.api | 10 +++- api/desc/core/pcm-core.api | 9 ++++ api/desc/pcm.api | 6 +++ api/internal/handler/ai/proxyapihandler.go | 24 ++++++++++ api/internal/handler/routes.go | 5 ++ api/internal/logic/ai/proxyapilogic.go | 56 ++++++++++++++++++++++ api/internal/types/types.go | 13 +++++ 7 files changed, 122 insertions(+), 1 deletion(-) create mode 100644 api/internal/handler/ai/proxyapihandler.go create mode 100644 api/internal/logic/ai/proxyapilogic.go diff --git a/api/desc/ai/pcm-ai.api b/api/desc/ai/pcm-ai.api index 4458d6c9..fefc7ea3 100644 --- a/api/desc/ai/pcm-ai.api +++ b/api/desc/ai/pcm-ai.api @@ -1818,4 +1818,12 @@ service AICore-api { get /getVisualizationJob (GetVisualizationJobReq) returns (GetVisualizationJobResp) @handler createVisualizationJobHandler post /CreateVisualizationJob (CreateVisualizationJobReq) returns (CreateVisualizationJobResp) -}*/ \ No newline at end of file +}*/ + +type ( + ChatReq{ + ApiUrl string `json:"apiUrl,optional"` + Method string `json:"method,optional"` + ReqData map[string]interface{} `json:"reqData"` + } +) \ No newline at end of file diff --git a/api/desc/core/pcm-core.api b/api/desc/core/pcm-core.api index 516409cc..0ad637b9 100644 --- a/api/desc/core/pcm-core.api +++ b/api/desc/core/pcm-core.api @@ -1256,5 +1256,14 @@ type ( ClusterName string `json:"clusterName" db:"cluster_name"` Status string `json:"status" db:"status"` Remark string `json:"remark" db:"remark"` + InferUrl string `json:"inferUrl"` + } +) + +type ( + CommonResp { + Code int `json:"code,omitempty"` + Msg string `json:"msg,omitempty"` + Data interface{} `json:"data,omitempty"` } ) \ No newline at end of file diff --git a/api/desc/pcm.api b/api/desc/pcm.api index 0a7b51bf..8ee57626 100644 --- a/api/desc/pcm.api +++ b/api/desc/pcm.api @@ -366,6 +366,12 @@ service pcm { @handler createVisualizationJobHandler post /ai/CreateVisualizationJob (CreateVisualizationJobReq) returns (CreateVisualizationJobResp) /******************Visualization Job Method start*************************/ + + /***********chat***********/ + @doc "文本识别" + @handler ProxyApiHandler + post /ai/chat (ChatReq) returns (CommonResp) + /******chat end***********/ } //screen接口 diff --git a/api/internal/handler/ai/proxyapihandler.go b/api/internal/handler/ai/proxyapihandler.go new file mode 100644 index 00000000..cbb732e2 --- /dev/null +++ b/api/internal/handler/ai/proxyapihandler.go @@ -0,0 +1,24 @@ +package ai + +import ( + "github.com/zeromicro/go-zero/rest/httpx" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/logic/ai" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types" + "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/repository/result" + "net/http" +) + +func ProxyApiHandler(svcCtx *svc.ServiceContext) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + var req types.ChatReq + if err := httpx.Parse(r, &req); err != nil { + result.ParamErrorResult(r, w, err) + return + } + + l := ai.NewProxyApiLogic(r.Context(), svcCtx) + resp, err := l.ProxyApi(&req, w) + result.HttpResult(r, w, resp, err) + } +} diff --git a/api/internal/handler/routes.go b/api/internal/handler/routes.go index 2daff30f..00a8e345 100644 --- a/api/internal/handler/routes.go +++ b/api/internal/handler/routes.go @@ -437,6 +437,11 @@ func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) { Path: "/ai/CreateVisualizationJob", Handler: ai.CreateVisualizationJobHandler(serverCtx), }, + { + Method: http.MethodPost, + Path: "/ai/chat", + Handler: ai.ProxyApiHandler(serverCtx), + }, }, rest.WithPrefix("/pcm/v1"), ) diff --git a/api/internal/logic/ai/proxyapilogic.go b/api/internal/logic/ai/proxyapilogic.go new file mode 100644 index 00000000..cf201644 --- /dev/null +++ b/api/internal/logic/ai/proxyapilogic.go @@ -0,0 +1,56 @@ +package ai + +import ( + "bytes" + "context" + "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/repository/result" + tool "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/utils" + "k8s.io/apimachinery/pkg/util/json" + "net/http" + + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types" + + "github.com/zeromicro/go-zero/core/logx" +) + +type ProxyApiLogic struct { + logx.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +func NewProxyApiLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ProxyApiLogic { + return &ProxyApiLogic{ + Logger: logx.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +type ChatResult struct { + Results string `json:"results"` +} + +func (l *ProxyApiLogic) ProxyApi(req *types.ChatReq, w http.ResponseWriter) (resp *types.CommonResp, err error) { + jsonBytes, err := json.Marshal(&req.ReqData) + // 调用第三方接口的 POST 方法 + respThirdParty, err := http.Post(req.ApiUrl, "application/json", bytes.NewBuffer(jsonBytes)) + if err != nil { + return + } + defer respThirdParty.Body.Close() + marshal, err := json.Marshal(&respThirdParty.Body) + if err != nil { + return nil, result.NewDefaultError(err.Error()) + } + json.Unmarshal(marshal, &resp) + + chatResult := &ChatResult{} + tool.Convert(resp, &chatResult.Results) + return &types.CommonResp{ + Code: respThirdParty.StatusCode, + Msg: "success", + Data: chatResult, + }, nil +} diff --git a/api/internal/types/types.go b/api/internal/types/types.go index 7ab16536..a549d98a 100644 --- a/api/internal/types/types.go +++ b/api/internal/types/types.go @@ -1179,6 +1179,13 @@ type SubTaskInfo struct { ClusterName string `json:"clusterName" db:"cluster_name"` Status string `json:"status" db:"status"` Remark string `json:"remark" db:"remark"` + InferUrl string `json:"inferUrl"` +} + +type CommonResp struct { + Code int `json:"code,omitempty"` + Msg string `json:"msg,omitempty"` + Data interface{} `json:"data,omitempty"` } type CommitHpcTaskReq struct { @@ -2869,6 +2876,12 @@ type AiTask struct { TimeElapsed int32 `json:"elapsed,optional"` } +type ChatReq struct { + ApiUrl string `json:"apiUrl,optional"` + Method string `json:"method,optional"` + ReqData map[string]interface{} `json:"reqData"` +} + type StorageScreenReq struct { } From 62923ca197a50604df1e301815e2c9860ed5dba5 Mon Sep 17 00:00:00 2001 From: qiwang <1364512070@qq.com> Date: Tue, 25 Jun 2024 20:27:54 +0800 Subject: [PATCH 12/21] fix: add iamge log 0625 Former-commit-id: c48e976f5bafd53bc153c36421374f2f382478dc --- api/internal/logic/inference/imageinferencelogic.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/api/internal/logic/inference/imageinferencelogic.go b/api/internal/logic/inference/imageinferencelogic.go index 151136bd..69732457 100644 --- a/api/internal/logic/inference/imageinferencelogic.go +++ b/api/internal/logic/inference/imageinferencelogic.go @@ -461,6 +461,7 @@ func getInferResult(url string, file multipart.File, fileName string, clusterNam func getInferResultModelarts(url string, file multipart.File, fileName string) (string, error) { var res Res body, err := utils.SendRequest("POST", url, file, fileName) + log.Fatalf("图形识别url: %s", url) if err != nil { return "", err } @@ -468,6 +469,7 @@ func getInferResultModelarts(url string, file multipart.File, fileName string) ( if errjson != nil { log.Fatalf("Error parsing JSON: %s", errjson) } + log.Fatalf("推理结果: %s", res.Result) return res.Result, nil } From 7cf64232706319e1e221aa5f8eceb87483a3fdbd Mon Sep 17 00:00:00 2001 From: qiwang <1364512070@qq.com> Date: Tue, 25 Jun 2024 20:41:19 +0800 Subject: [PATCH 13/21] fix: add image log 0625 Former-commit-id: 513de69ad86c3c00f985dffedd8e0e1ad8b19344 --- api/internal/logic/inference/imageinferencelogic.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/api/internal/logic/inference/imageinferencelogic.go b/api/internal/logic/inference/imageinferencelogic.go index 69732457..fcb20376 100644 --- a/api/internal/logic/inference/imageinferencelogic.go +++ b/api/internal/logic/inference/imageinferencelogic.go @@ -441,6 +441,7 @@ func sendInferReq(images []struct { func getInferResult(url string, file multipart.File, fileName string, clusterName string) (string, error) { if clusterName == "鹏城云脑II-modelarts" { r, err := getInferResultModelarts(url, file, fileName) + log.Printf("图形识别url: %s", url) if err != nil { return "", err } @@ -461,7 +462,7 @@ func getInferResult(url string, file multipart.File, fileName string, clusterNam func getInferResultModelarts(url string, file multipart.File, fileName string) (string, error) { var res Res body, err := utils.SendRequest("POST", url, file, fileName) - log.Fatalf("图形识别url: %s", url) + log.Printf("图形识别url: %s", url) if err != nil { return "", err } @@ -469,7 +470,7 @@ func getInferResultModelarts(url string, file multipart.File, fileName string) ( if errjson != nil { log.Fatalf("Error parsing JSON: %s", errjson) } - log.Fatalf("推理结果: %s", res.Result) + log.Printf("推理结果: %s", res.Result) return res.Result, nil } From 127c10de85e77096b071e881fbecebde247d2f6f Mon Sep 17 00:00:00 2001 From: jagger Date: Tue, 25 Jun 2024 20:53:22 +0800 Subject: [PATCH 14/21] fix bug Signed-off-by: jagger Former-commit-id: b7dc16441702ba2da13d8a55ac2c71f268045dfb --- api/internal/logic/ai/proxyapilogic.go | 53 +++++-- pkg/utils/hws/escape.go | 38 +++++ pkg/utils/hws/signer.go | 184 +++++++++++++++++++++++++ 3 files changed, 265 insertions(+), 10 deletions(-) create mode 100644 pkg/utils/hws/escape.go create mode 100644 pkg/utils/hws/signer.go diff --git a/api/internal/logic/ai/proxyapilogic.go b/api/internal/logic/ai/proxyapilogic.go index cf201644..1e7235b0 100644 --- a/api/internal/logic/ai/proxyapilogic.go +++ b/api/internal/logic/ai/proxyapilogic.go @@ -3,9 +3,11 @@ package ai import ( "bytes" "context" - "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/repository/result" + "crypto/tls" + "encoding/json" + "fmt" tool "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/utils" - "k8s.io/apimachinery/pkg/util/json" + "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/utils/hws" "net/http" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc" @@ -32,24 +34,55 @@ type ChatResult struct { Results string `json:"results"` } +type ResponseData struct { + Results string `json:"results"` +} + func (l *ProxyApiLogic) ProxyApi(req *types.ChatReq, w http.ResponseWriter) (resp *types.CommonResp, err error) { + jsonBytes, err := json.Marshal(&req.ReqData) // 调用第三方接口的 POST 方法 - respThirdParty, err := http.Post(req.ApiUrl, "application/json", bytes.NewBuffer(jsonBytes)) + thirdReq, err := http.NewRequest("POST", req.ApiUrl, bytes.NewBuffer(jsonBytes)) if err != nil { return } - defer respThirdParty.Body.Close() - marshal, err := json.Marshal(&respThirdParty.Body) - if err != nil { - return nil, result.NewDefaultError(err.Error()) + + signer := &hws.Signer{ + Key: "UNEHPHO4Z7YSNPKRXFE4", + Secret: "JWXCE9qcYbc7RjpSRIWt4WgG3ZKF6Q4lPzkJReX9", + } + + if err := signer.Sign(thirdReq); err != nil { + return nil, err + } + + // 设置client信任所有证书 + tr := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + client := &http.Client{ + Transport: tr, + } + + thirdReq.Header.Set("X-Project-Id", "d18190e28e3f45a281ef0b0696ec9d52") + thirdReq.Header.Set("x-stage", "RELEASE") + thirdReq.Header.Set("Authorization", thirdReq.Header.Get(hws.HeaderXAuthorization)) + thirdReq.Header.Set("X-Sdk-Date", thirdReq.Header.Get(hws.HeaderXDateTime)) + thirdReq.Header.Set("Content-Type", "application/json") + + thirdResp, err := client.Do(thirdReq) + + defer thirdReq.Body.Close() + var responseData ResponseData + decoder := json.NewDecoder(thirdResp.Body) + if err := decoder.Decode(&responseData); err != nil { + fmt.Println("Error decoding response:", err) } - json.Unmarshal(marshal, &resp) chatResult := &ChatResult{} - tool.Convert(resp, &chatResult.Results) + tool.Convert(responseData, &chatResult) return &types.CommonResp{ - Code: respThirdParty.StatusCode, + Code: thirdResp.StatusCode, Msg: "success", Data: chatResult, }, nil diff --git a/pkg/utils/hws/escape.go b/pkg/utils/hws/escape.go new file mode 100644 index 00000000..5c311ce4 --- /dev/null +++ b/pkg/utils/hws/escape.go @@ -0,0 +1,38 @@ +package hws + +func shouldEscape(c byte) bool { + if 'A' <= c && c <= 'Z' || 'a' <= c && c <= 'z' || '0' <= c && c <= '9' || c == '_' || c == '-' || c == '~' || c == '.' { + return false + } + return true +} + +func Escape(s string) string { + hexCount := 0 + for i := 0; i < len(s); i++ { + c := s[i] + if shouldEscape(c) { + hexCount++ + } + } + + if hexCount == 0 { + return s + } + + t := make([]byte, len(s)+2*hexCount) + j := 0 + for i := 0; i < len(s); i++ { + switch c := s[i]; { + case shouldEscape(c): + t[j] = '%' + t[j+1] = "0123456789ABCDEF"[c>>4] + t[j+2] = "0123456789ABCDEF"[c&15] + j += 3 + default: + t[j] = s[i] + j++ + } + } + return string(t) +} diff --git a/pkg/utils/hws/signer.go b/pkg/utils/hws/signer.go new file mode 100644 index 00000000..1455c82a --- /dev/null +++ b/pkg/utils/hws/signer.go @@ -0,0 +1,184 @@ +package hws + +import ( + "bytes" + "crypto/hmac" + "crypto/sha256" + "fmt" + "io/ioutil" + "net/http" + "sort" + "strings" + "time" +) + +const ( + DateFormat = "20060102T150405Z" + SignAlgorithm = "SDK-HMAC-SHA256" + HeaderXDateTime = "X-Sdk-Date" + HeaderXHost = "host" + HeaderXAuthorization = "Authorization" + HeaderXContentSha256 = "X-Sdk-Content-Sha256" +) + +func hmacsha256(keyByte []byte, dataStr string) ([]byte, error) { + hm := hmac.New(sha256.New, []byte(keyByte)) + if _, err := hm.Write([]byte(dataStr)); err != nil { + return nil, err + } + return hm.Sum(nil), nil +} + +func CanonicalRequest(request *http.Request, signedHeaders []string) (string, error) { + var hexencode string + var err error + if hex := request.Header.Get(HeaderXContentSha256); hex != "" { + hexencode = hex + } else { + bodyData, err := RequestPayload(request) + if err != nil { + return "", err + } + hexencode, err = HexEncodeSHA256Hash(bodyData) + if err != nil { + return "", err + } + } + return fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", request.Method, CanonicalURI(request), CanonicalQueryString(request), CanonicalHeaders(request, signedHeaders), strings.Join(signedHeaders, ";"), hexencode), err +} + +func CanonicalURI(request *http.Request) string { + pattens := strings.Split(request.URL.Path, "/") + var uriSlice []string + for _, v := range pattens { + uriSlice = append(uriSlice, Escape(v)) + } + urlpath := strings.Join(uriSlice, "/") + if len(urlpath) == 0 || urlpath[len(urlpath)-1] != '/' { + urlpath = urlpath + "/" + } + return urlpath +} + +func CanonicalQueryString(request *http.Request) string { + var keys []string + queryMap := request.URL.Query() + for key := range queryMap { + keys = append(keys, key) + } + sort.Strings(keys) + var query []string + for _, key := range keys { + k := Escape(key) + sort.Strings(queryMap[key]) + for _, v := range queryMap[key] { + kv := fmt.Sprintf("%s=%s", k, Escape(v)) + query = append(query, kv) + } + } + queryStr := strings.Join(query, "&") + request.URL.RawQuery = queryStr + return queryStr +} + +func CanonicalHeaders(request *http.Request, signerHeaders []string) string { + var canonicalHeaders []string + header := make(map[string][]string) + for k, v := range request.Header { + header[strings.ToLower(k)] = v + } + for _, key := range signerHeaders { + value := header[key] + if strings.EqualFold(key, HeaderXHost) { + value = []string{request.Host} + } + sort.Strings(value) + for _, v := range value { + canonicalHeaders = append(canonicalHeaders, key+":"+strings.TrimSpace(v)) + } + } + return fmt.Sprintf("%s\n", strings.Join(canonicalHeaders, "\n")) +} + +func SignedHeaders(r *http.Request) []string { + var signedHeaders []string + for key := range r.Header { + signedHeaders = append(signedHeaders, strings.ToLower(key)) + } + sort.Strings(signedHeaders) + return signedHeaders +} + +func RequestPayload(request *http.Request) ([]byte, error) { + if request.Body == nil { + return []byte(""), nil + } + bodyByte, err := ioutil.ReadAll(request.Body) + if err != nil { + return []byte(""), err + } + request.Body = ioutil.NopCloser(bytes.NewBuffer(bodyByte)) + return bodyByte, err +} + +func StringToSign(canonicalRequest string, t time.Time) (string, error) { + hashStruct := sha256.New() + _, err := hashStruct.Write([]byte(canonicalRequest)) + if err != nil { + return "", err + } + return fmt.Sprintf("%s\n%s\n%x", + SignAlgorithm, t.UTC().Format(DateFormat), hashStruct.Sum(nil)), nil +} + +func SignStringToSign(stringToSign string, signingKey []byte) (string, error) { + hmsha, err := hmacsha256(signingKey, stringToSign) + return fmt.Sprintf("%x", hmsha), err +} + +func HexEncodeSHA256Hash(body []byte) (string, error) { + hashStruct := sha256.New() + if len(body) == 0 { + body = []byte("") + } + _, err := hashStruct.Write(body) + return fmt.Sprintf("%x", hashStruct.Sum(nil)), err +} + +func AuthHeaderValue(signatureStr, accessKeyStr string, signedHeaders []string) string { + return fmt.Sprintf("%s Access=%s, SignedHeaders=%s, Signature=%s", SignAlgorithm, accessKeyStr, strings.Join(signedHeaders, ";"), signatureStr) +} + +type Signer struct { + Key string + Secret string +} + +func (s *Signer) Sign(request *http.Request) error { + var t time.Time + var err error + var date string + if date = request.Header.Get(HeaderXDateTime); date != "" { + t, err = time.Parse(DateFormat, date) + } + if err != nil || date == "" { + t = time.Now() + request.Header.Set(HeaderXDateTime, t.UTC().Format(DateFormat)) + } + signedHeaders := SignedHeaders(request) + canonicalRequest, err := CanonicalRequest(request, signedHeaders) + if err != nil { + return err + } + stringToSignStr, err := StringToSign(canonicalRequest, t) + if err != nil { + return err + } + signatureStr, err := SignStringToSign(stringToSignStr, []byte(s.Secret)) + if err != nil { + return err + } + authValueStr := AuthHeaderValue(signatureStr, s.Key, signedHeaders) + request.Header.Set(HeaderXAuthorization, authValueStr) + return nil +} From c0b64237badf595a62625ec1e50bf71ab01f7575 Mon Sep 17 00:00:00 2001 From: tzwang Date: Tue, 25 Jun 2024 21:33:28 +0800 Subject: [PATCH 15/21] updated textinfer api Former-commit-id: 0ae4e41fae1f4bd8ee8d8c6185cd43e38c3d6068 --- api/internal/logic/inference/imageinferencelogic.go | 4 ++++ api/internal/storeLink/octopus.go | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/api/internal/logic/inference/imageinferencelogic.go b/api/internal/logic/inference/imageinferencelogic.go index 48aaf97a..25036805 100644 --- a/api/internal/logic/inference/imageinferencelogic.go +++ b/api/internal/logic/inference/imageinferencelogic.go @@ -8,6 +8,7 @@ import ( "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/schedulers/option" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/collector" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/strategy" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/storeLink" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types" "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants" @@ -183,6 +184,9 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s c := cluster go func() { imageUrls, err := collectorMap[c.ClusterId].GetInferUrl(ctx, opt) + for i, _ := range imageUrls { + imageUrls[i].Url = imageUrls[i].Url + storeLink.FORWARD_SLASH + "image" + } if err != nil { wg.Done() return diff --git a/api/internal/storeLink/octopus.go b/api/internal/storeLink/octopus.go index 9c605345..ed706c0a 100644 --- a/api/internal/storeLink/octopus.go +++ b/api/internal/storeLink/octopus.go @@ -888,7 +888,7 @@ func (o *OctopusLink) GetInferUrl(ctx context.Context, option *option.InferOptio url := strings.Replace(notebook.Tasks[0].Url, FORWARD_SLASH, "", -1) names := strings.Split(notebook.AlgorithmName, UNDERSCORE) imageUrl := &collector.InferUrl{ - Url: DOMAIN + url + FORWARD_SLASH + "image", + Url: DOMAIN + url, Card: names[2], } imageUrls = append(imageUrls, imageUrl) From 27dd00af1ca3b16fbc1dd770487f9d9ca8a89d3c Mon Sep 17 00:00:00 2001 From: tzwang Date: Tue, 25 Jun 2024 22:07:11 +0800 Subject: [PATCH 16/21] fix crontask bug Former-commit-id: 545109fd9289b9401fd3daddac2f018b5961a331 --- api/internal/cron/aiCronTask.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/api/internal/cron/aiCronTask.go b/api/internal/cron/aiCronTask.go index 3b14cf3c..4ad76deb 100644 --- a/api/internal/cron/aiCronTask.go +++ b/api/internal/cron/aiCronTask.go @@ -561,6 +561,10 @@ func UpdateInferTaskStatus(svc *svc.ServiceContext, task *types.TaskModel) { // return //} + if aiTask[0].StartTime == "" { + return + } + start, _ := time.ParseInLocation(time.RFC3339, aiTask[0].StartTime, time.Local) end, _ := time.ParseInLocation(time.RFC3339, aiTask[0].EndTime, time.Local) var status string From c2f5349f5c546334b19be8480af4d066b519b237 Mon Sep 17 00:00:00 2001 From: tzwang Date: Tue, 25 Jun 2024 22:55:12 +0800 Subject: [PATCH 17/21] fix tasklist bugs Former-commit-id: f97b9f3eac995728af77aff0bb174d7a3449a858 --- api/internal/logic/core/pagelisttasklogic.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/api/internal/logic/core/pagelisttasklogic.go b/api/internal/logic/core/pagelisttasklogic.go index 52ad4171..29e5d000 100644 --- a/api/internal/logic/core/pagelisttasklogic.go +++ b/api/internal/logic/core/pagelisttasklogic.go @@ -57,9 +57,9 @@ func (l *PageListTaskLogic) PageListTask(req *types.PageTaskReq) (resp *types.Pa } // 更新智算任务状态 - chs := [2]chan struct{}{make(chan struct{}), make(chan struct{})} - go l.updateTaskStatus(list, chs[0]) - go l.updateAiTaskStatus(list, chs[1]) + //chs := [2]chan struct{}{make(chan struct{}), make(chan struct{})} + //go l.updateTaskStatus(list, chs[0]) + //go l.updateAiTaskStatus(list, chs[1]) for _, model := range list { if model.StartTime != "" && model.EndTime == "" { @@ -77,12 +77,12 @@ func (l *PageListTaskLogic) PageListTask(req *types.PageTaskReq) (resp *types.Pa resp.PageNum = req.PageNum resp.Total = total - for _, ch := range chs { - select { - case <-ch: - case <-time.After(1 * time.Second): - } - } + //for _, ch := range chs { + // select { + // case <-ch: + // case <-time.After(1 * time.Second): + // } + //} return } From 504b3e95d9e90ff0441e777fae78485d7723cd43 Mon Sep 17 00:00:00 2001 From: tzwang Date: Wed, 26 Jun 2024 09:28:08 +0800 Subject: [PATCH 18/21] fix crontask bug Former-commit-id: 0a317a3ff1e27eb3d9a774bb0ba9536248fa3586 --- api/internal/cron/aiCronTask.go | 37 +++++++-------------------------- 1 file changed, 7 insertions(+), 30 deletions(-) diff --git a/api/internal/cron/aiCronTask.go b/api/internal/cron/aiCronTask.go index 4ad76deb..254a16e6 100644 --- a/api/internal/cron/aiCronTask.go +++ b/api/internal/cron/aiCronTask.go @@ -174,7 +174,7 @@ func UpdateTaskStatus(svc *svc.ServiceContext, tasklist []*types.TaskModel) { } // Update Infer Task Status - if task.TaskTypeDict == 11 { + if task.TaskTypeDict == 11 || task.TaskTypeDict == 12 { UpdateInferTaskStatus(svc, task) return } @@ -229,30 +229,14 @@ func UpdateTaskStatus(svc *svc.ServiceContext, tasklist []*types.TaskModel) { return } - var start time.Time - var end time.Time - // distinguish train or infer temporarily - if task.TaskTypeDict == 11 { - start, _ = time.ParseInLocation(time.RFC3339, aiTask[0].StartTime, time.Local) - end, _ = time.ParseInLocation(time.RFC3339, aiTask[0].EndTime, time.Local) - } else { - start, _ = time.ParseInLocation(constants.Layout, aiTask[0].StartTime, time.Local) - end, _ = time.ParseInLocation(constants.Layout, aiTask[0].EndTime, time.Local) - } + start, _ := time.ParseInLocation(constants.Layout, aiTask[0].StartTime, time.Local) + end, _ := time.ParseInLocation(constants.Layout, aiTask[0].EndTime, time.Local) var status string var count int for _, a := range aiTask { - var s time.Time - var e time.Time - // distinguish train or infer temporarily - if task.TaskTypeDict == 11 { - s, _ = time.ParseInLocation(time.RFC3339, a.StartTime, time.Local) - e, _ = time.ParseInLocation(time.RFC3339, a.EndTime, time.Local) - } else { - s, _ = time.ParseInLocation(constants.Layout, a.StartTime, time.Local) - e, _ = time.ParseInLocation(constants.Layout, a.EndTime, time.Local) - } + s, _ := time.ParseInLocation(constants.Layout, a.StartTime, time.Local) + e, _ := time.ParseInLocation(constants.Layout, a.EndTime, time.Local) if s.Before(start) { start = s @@ -289,15 +273,8 @@ func UpdateTaskStatus(svc *svc.ServiceContext, tasklist []*types.TaskModel) { if status != "" { task.Status = status - // distinguish train or infer temporarily - if task.TaskTypeDict == 11 { - task.StartTime = start.Format(time.RFC3339) - task.EndTime = end.Format(time.RFC3339) - } else { - task.StartTime = start.Format(constants.Layout) - task.EndTime = end.Format(constants.Layout) - } - + task.StartTime = start.Format(constants.Layout) + task.EndTime = end.Format(constants.Layout) } task.UpdatedTime = time.Now().Format(constants.Layout) From bee190493f0a577e9ed05d34ad72a44f22c57a3d Mon Sep 17 00:00:00 2001 From: tzwang Date: Wed, 26 Jun 2024 16:31:09 +0800 Subject: [PATCH 19/21] updated taskmodel api type Former-commit-id: 8d89e912f552e5f3ca2390610186905b2886578d --- api/desc/core/pcm-core.api | 4 ++-- api/internal/types/types.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/api/desc/core/pcm-core.api b/api/desc/core/pcm-core.api index 516409cc..f8fb1ac7 100644 --- a/api/desc/core/pcm-core.api +++ b/api/desc/core/pcm-core.api @@ -404,8 +404,8 @@ type ( TenantId string `json:"tenantId,omitempty" db:"tenant_id"` CreatedTime string `json:"createdTime,omitempty" db:"created_time" gorm:"autoCreateTime"` UpdatedTime string `json:"updatedTime,omitempty" db:"updated_time"` - AdapterTypeDict int `json:"adapterTypeDict" db:"adapter_type_dict" gorm:"adapter_type_dict"` //适配器类型(对应字典表的值 - TaskTypeDict int `json:"taskTypeDict" db:"task_type_dict" gorm:"task_type_dict"` //任务类型(对应字典表的值 + AdapterTypeDict string `json:"adapterTypeDict" db:"adapter_type_dict" gorm:"adapter_type_dict"` //适配器类型(对应字典表的值 + TaskTypeDict string `json:"taskTypeDict" db:"task_type_dict" gorm:"task_type_dict"` //任务类型(对应字典表的值 } ) diff --git a/api/internal/types/types.go b/api/internal/types/types.go index 7ab16536..0b53492c 100644 --- a/api/internal/types/types.go +++ b/api/internal/types/types.go @@ -347,8 +347,8 @@ type TaskModel struct { TenantId string `json:"tenantId,omitempty" db:"tenant_id"` CreatedTime string `json:"createdTime,omitempty" db:"created_time" gorm:"autoCreateTime"` UpdatedTime string `json:"updatedTime,omitempty" db:"updated_time"` - AdapterTypeDict int `json:"adapterTypeDict" db:"adapter_type_dict" gorm:"adapter_type_dict"` //适配器类型(对应字典表的值 - TaskTypeDict int `json:"taskTypeDict" db:"task_type_dict" gorm:"task_type_dict"` //任务类型(对应字典表的值 + AdapterTypeDict string `json:"adapterTypeDict" db:"adapter_type_dict" gorm:"adapter_type_dict"` //适配器类型(对应字典表的值 + TaskTypeDict string `json:"taskTypeDict" db:"task_type_dict" gorm:"task_type_dict"` //任务类型(对应字典表的值 } type TaskDetailReq struct { From 02d5da24f031a319c9b99c4cba0423b072c7c603 Mon Sep 17 00:00:00 2001 From: qiwang <1364512070@qq.com> Date: Wed, 26 Jun 2024 16:58:52 +0800 Subject: [PATCH 20/21] fix: update vminfo 0626 Former-commit-id: d3d9e2069595db59945f20b3cd422fb17bdbd71f --- api/client/types.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/client/types.go b/api/client/types.go index bf5c676d..2aa28e22 100644 --- a/api/client/types.go +++ b/api/client/types.go @@ -175,7 +175,7 @@ type VmInfo struct { //DeletedAt string `json:"deletedAt,omitempty"` VmName string `json:"vmName,omitempty"` Replicas int64 `json:"replicas,omitempty"` - ServerId string `json:"serverId,omitempty"` + //ServerId string `json:"serverId,omitempty"` } type ResourceStats struct { From 28e9deea2c3c6a34ec5c20ebc7716e04e5bf6055 Mon Sep 17 00:00:00 2001 From: tzwang Date: Wed, 26 Jun 2024 18:35:15 +0800 Subject: [PATCH 21/21] fix imageinfer api bug Former-commit-id: 1b914196fd8e5069d25c4e39994720d501cd38f5 --- api/internal/cron/aiCronTask.go | 6 +- api/internal/logic/core/pagelisttasklogic.go | 4 +- .../logic/inference/imageinferencelogic.go | 406 +----------------- api/internal/scheduler/database/aiStorage.go | 17 + api/internal/scheduler/service/aiService.go | 14 + .../scheduler/service/inference/imageInfer.go | 385 +++++++++++++++++ 6 files changed, 433 insertions(+), 399 deletions(-) create mode 100644 api/internal/scheduler/service/inference/imageInfer.go diff --git a/api/internal/cron/aiCronTask.go b/api/internal/cron/aiCronTask.go index 254a16e6..fdf2a767 100644 --- a/api/internal/cron/aiCronTask.go +++ b/api/internal/cron/aiCronTask.go @@ -58,7 +58,7 @@ func UpdateAiTaskStatus(svc *svc.ServiceContext, tasklist []*types.TaskModel) { list := make([]*types.TaskModel, len(tasklist)) copy(list, tasklist) for i := len(list) - 1; i >= 0; i-- { - if list[i].AdapterTypeDict != 1 || list[i].Status == constants.Succeeded || list[i].Status == constants.Failed { + if list[i].AdapterTypeDict != "1" || list[i].Status == constants.Succeeded || list[i].Status == constants.Failed { list = append(list[:i], list[i+1:]...) } } @@ -155,7 +155,7 @@ func UpdateTaskStatus(svc *svc.ServiceContext, tasklist []*types.TaskModel) { list := make([]*types.TaskModel, len(tasklist)) copy(list, tasklist) for i := len(list) - 1; i >= 0; i-- { - if list[i].AdapterTypeDict != 1 || list[i].Status == constants.Succeeded || list[i].Status == constants.Failed { + if list[i].AdapterTypeDict != "1" || list[i].Status == constants.Succeeded || list[i].Status == constants.Failed { list = append(list[:i], list[i+1:]...) } } @@ -174,7 +174,7 @@ func UpdateTaskStatus(svc *svc.ServiceContext, tasklist []*types.TaskModel) { } // Update Infer Task Status - if task.TaskTypeDict == 11 || task.TaskTypeDict == 12 { + if task.TaskTypeDict == "11" || task.TaskTypeDict == "12" { UpdateInferTaskStatus(svc, task) return } diff --git a/api/internal/logic/core/pagelisttasklogic.go b/api/internal/logic/core/pagelisttasklogic.go index 29e5d000..fa56c9f0 100644 --- a/api/internal/logic/core/pagelisttasklogic.go +++ b/api/internal/logic/core/pagelisttasklogic.go @@ -90,7 +90,7 @@ func (l *PageListTaskLogic) updateTaskStatus(tasklist []*types.TaskModel, ch cha list := make([]*types.TaskModel, len(tasklist)) copy(list, tasklist) for i := len(list) - 1; i >= 0; i-- { - if list[i].AdapterTypeDict != 1 || list[i].Status == constants.Succeeded || list[i].Status == constants.Failed { + if list[i].AdapterTypeDict != "1" || list[i].Status == constants.Succeeded || list[i].Status == constants.Failed { list = append(list[:i], list[i+1:]...) } } @@ -228,7 +228,7 @@ func (l *PageListTaskLogic) updateAiTaskStatus(tasklist []*types.TaskModel, ch c list := make([]*types.TaskModel, len(tasklist)) copy(list, tasklist) for i := len(list) - 1; i >= 0; i-- { - if list[i].AdapterTypeDict != 1 || list[i].Status == constants.Succeeded || list[i].Status == constants.Failed { + if list[i].AdapterTypeDict != "1" || list[i].Status == constants.Succeeded || list[i].Status == constants.Failed { list = append(list[:i], list[i+1:]...) } } diff --git a/api/internal/logic/inference/imageinferencelogic.go b/api/internal/logic/inference/imageinferencelogic.go index 1b32bbb1..f10a3c53 100644 --- a/api/internal/logic/inference/imageinferencelogic.go +++ b/api/internal/logic/inference/imageinferencelogic.go @@ -3,25 +3,13 @@ package inference import ( "context" "errors" - "github.com/go-resty/resty/v2" "github.com/zeromicro/go-zero/core/logx" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/schedulers/option" - "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/collector" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/inference" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/strategy" - "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/storeLink" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types" - "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants" - "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models" - "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/utils" - "k8s.io/apimachinery/pkg/util/json" - "log" - "math/rand" - "mime/multipart" "net/http" - "strconv" - "sync" - "time" ) type ImageInferenceLogic struct { @@ -55,10 +43,7 @@ func (l *ImageInferenceLogic) ImageInfer(r *http.Request, req *types.ImageInfere StaticWeightMap: req.StaticWeightMap, } - var ts []struct { - imageResult *types.ImageResult - file multipart.File - } + var ts []*inference.ImageFile uploadedFiles := r.MultipartForm.File @@ -78,14 +63,11 @@ func (l *ImageInferenceLogic) ImageInfer(r *http.Request, req *types.ImageInfere defer file.Close() var ir types.ImageResult ir.ImageName = header.Filename - t := struct { - imageResult *types.ImageResult - file multipart.File - }{ - imageResult: &ir, - file: file, + t := inference.ImageFile{ + ImageResult: &ir, + File: file, } - ts = append(ts, t) + ts = append(ts, &t) } _, ok := l.svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId] @@ -108,396 +90,32 @@ func (l *ImageInferenceLogic) ImageInfer(r *http.Request, req *types.ImageInfere return nil, err } - results, err := infer(opt, clusters, ts, l.svcCtx, l.ctx) - if err != nil { - return nil, err - } - resp.InferResults = results - - return resp, nil -} - -func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []struct { - imageResult *types.ImageResult - file multipart.File -}, svcCtx *svc.ServiceContext, ctx context.Context) ([]*types.ImageResult, error) { - if clusters == nil || len(clusters) == 0 { return nil, errors.New("clusters is nil") } - for i := len(clusters) - 1; i >= 0; i-- { - if clusters[i].Replicas == 0 { - clusters = append(clusters[:i], clusters[i+1:]...) - } - } - - var wg sync.WaitGroup - var cluster_ch = make(chan struct { - urls []*collector.InferUrl - clusterId string - clusterName string - imageNum int32 - }, len(clusters)) - - var cs []struct { - urls []*collector.InferUrl - clusterId string - clusterName string - imageNum int32 - } - collectorMap := svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId] - //save task var synergystatus int64 if len(clusters) > 1 { synergystatus = 1 } - strategyCode, err := svcCtx.Scheduler.AiStorages.GetStrategyCode(opt.Strategy) + strategyCode, err := l.svcCtx.Scheduler.AiStorages.GetStrategyCode(opt.Strategy) if err != nil { return nil, err } - adapterName, err := svcCtx.Scheduler.AiStorages.GetAdapterNameById(opt.AdapterId) + adapterName, err := l.svcCtx.Scheduler.AiStorages.GetAdapterNameById(opt.AdapterId) if err != nil { return nil, err } - id, err := svcCtx.Scheduler.AiStorages.SaveTask(opt.TaskName, strategyCode, synergystatus, "11") + id, err := l.svcCtx.Scheduler.AiStorages.SaveTask(opt.TaskName, strategyCode, synergystatus, "11") if err != nil { return nil, err } - svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "create", "任务创建中") + l.svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "create", "任务创建中") - //save taskai - for _, c := range clusters { - clusterName, _ := svcCtx.Scheduler.AiStorages.GetClusterNameById(c.ClusterId) - opt.Replica = c.Replicas - err := svcCtx.Scheduler.AiStorages.SaveAiTask(id, opt, adapterName, c.ClusterId, clusterName, "", constants.Saved, "") - if err != nil { - return nil, err - } - } + go l.svcCtx.Scheduler.AiService.ImageInfer(opt, id, adapterName, clusters, ts, l.ctx) - for _, cluster := range clusters { - wg.Add(1) - c := cluster - go func() { - imageUrls, err := collectorMap[c.ClusterId].GetInferUrl(ctx, opt) - for i, _ := range imageUrls { - imageUrls[i].Url = imageUrls[i].Url + storeLink.FORWARD_SLASH + "image" - } - if err != nil { - wg.Done() - return - } - clusterName, _ := svcCtx.Scheduler.AiStorages.GetClusterNameById(c.ClusterId) - - s := struct { - urls []*collector.InferUrl - clusterId string - clusterName string - imageNum int32 - }{ - urls: imageUrls, - clusterId: c.ClusterId, - clusterName: clusterName, - imageNum: c.Replicas, - } - - cluster_ch <- s - wg.Done() - return - }() - } - wg.Wait() - close(cluster_ch) - - for s := range cluster_ch { - cs = append(cs, s) - } - - var aiTaskList []*models.TaskAi - tx := svcCtx.DbEngin.Raw("select * from task_ai where `task_id` = ? ", id).Scan(&aiTaskList) - if tx.Error != nil { - return nil, tx.Error - - } - - //no cluster available - if len(cs) == 0 { - for _, t := range aiTaskList { - t.Status = constants.Failed - t.EndTime = time.Now().Format(time.RFC3339) - err := svcCtx.Scheduler.AiStorages.UpdateAiTask(t) - if err != nil { - logx.Errorf(tx.Error.Error()) - } - } - svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败") - return nil, errors.New("image infer task failed") - } - - //change cluster status - if len(clusters) != len(cs) { - var acs []*strategy.AssignedCluster - var rcs []*strategy.AssignedCluster - for _, cluster := range clusters { - if contains(cs, cluster.ClusterId) { - var ac *strategy.AssignedCluster - ac = cluster - rcs = append(rcs, ac) - } else { - var ac *strategy.AssignedCluster - ac = cluster - acs = append(acs, ac) - } - } - - // update failed cluster status - for _, ac := range acs { - for _, t := range aiTaskList { - if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) { - t.Status = constants.Failed - t.EndTime = time.Now().Format(time.RFC3339) - err := svcCtx.Scheduler.AiStorages.UpdateAiTask(t) - if err != nil { - logx.Errorf(tx.Error.Error()) - } - } - } - } - - // update running cluster status - for _, ac := range rcs { - for _, t := range aiTaskList { - if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) { - t.Status = constants.Running - err := svcCtx.Scheduler.AiStorages.UpdateAiTask(t) - if err != nil { - logx.Errorf(tx.Error.Error()) - } - } - } - } - svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败") - } else { - for _, t := range aiTaskList { - t.Status = constants.Running - err := svcCtx.Scheduler.AiStorages.UpdateAiTask(t) - if err != nil { - logx.Errorf(tx.Error.Error()) - } - } - svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "running", "任务运行中") - } - - var result_ch = make(chan *types.ImageResult, len(ts)) - var results []*types.ImageResult - limit := make(chan bool, 7) - - var imageNumIdx int32 = 0 - var imageNumIdxEnd int32 = 0 - for _, c := range cs { - new_images := make([]struct { - imageResult *types.ImageResult - file multipart.File - }, len(ts)) - copy(new_images, ts) - - imageNumIdxEnd = imageNumIdxEnd + c.imageNum - new_images = new_images[imageNumIdx:imageNumIdxEnd] - imageNumIdx = imageNumIdx + c.imageNum - - wg.Add(len(new_images)) - go sendInferReq(new_images, c, &wg, result_ch, limit) - } - wg.Wait() - close(result_ch) - - for s := range result_ch { - results = append(results, s) - } - - //sort.Slice(results, func(p, q int) bool { - // return results[p].ClusterName < results[q].ClusterName - //}) - - //save ai sub tasks - for _, r := range results { - for _, task := range aiTaskList { - if r.ClusterId == strconv.Itoa(int(task.ClusterId)) { - taskAiSub := models.TaskAiSub{ - TaskId: id, - TaskName: task.Name, - TaskAiId: task.TaskId, - TaskAiName: task.Name, - ImageName: r.ImageName, - Result: r.ImageResult, - Card: r.Card, - ClusterId: task.ClusterId, - ClusterName: r.ClusterName, - } - tx := svcCtx.DbEngin.Table("task_ai_sub").Create(&taskAiSub) - if tx.Error != nil { - logx.Errorf(err.Error()) - } - } - } - } - - // update succeeded cluster status - var successStatusCount int - for _, c := range cs { - for _, t := range aiTaskList { - if c.clusterId == strconv.Itoa(int(t.ClusterId)) { - t.Status = constants.Completed - t.EndTime = time.Now().Format(time.RFC3339) - err := svcCtx.Scheduler.AiStorages.UpdateAiTask(t) - if err != nil { - logx.Errorf(tx.Error.Error()) - } - successStatusCount++ - } else { - continue - } - } - } - - if len(cs) == successStatusCount { - svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "completed", "任务完成") - } else { - svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败") - } - - return results, nil -} - -func sendInferReq(images []struct { - imageResult *types.ImageResult - file multipart.File -}, cluster struct { - urls []*collector.InferUrl - clusterId string - clusterName string - imageNum int32 -}, wg *sync.WaitGroup, ch chan<- *types.ImageResult, limit chan bool) { - for _, image := range images { - limit <- true - go func(t struct { - imageResult *types.ImageResult - file multipart.File - }, c struct { - urls []*collector.InferUrl - clusterId string - clusterName string - imageNum int32 - }) { - if len(c.urls) == 1 { - r, err := getInferResult(c.urls[0].Url, t.file, t.imageResult.ImageName, c.clusterName) - if err != nil { - t.imageResult.ImageResult = err.Error() - t.imageResult.ClusterId = c.clusterId - t.imageResult.ClusterName = c.clusterName - t.imageResult.Card = c.urls[0].Card - ch <- t.imageResult - wg.Done() - <-limit - return - } - t.imageResult.ImageResult = r - t.imageResult.ClusterId = c.clusterId - t.imageResult.ClusterName = c.clusterName - t.imageResult.Card = c.urls[0].Card - - ch <- t.imageResult - wg.Done() - <-limit - return - } else { - idx := rand.Intn(len(c.urls)) - r, err := getInferResult(c.urls[idx].Url, t.file, t.imageResult.ImageName, c.clusterName) - if err != nil { - t.imageResult.ImageResult = err.Error() - t.imageResult.ClusterId = c.clusterId - t.imageResult.ClusterName = c.clusterName - t.imageResult.Card = c.urls[idx].Card - ch <- t.imageResult - wg.Done() - <-limit - return - } - t.imageResult.ImageResult = r - t.imageResult.ClusterId = c.clusterId - t.imageResult.ClusterName = c.clusterName - t.imageResult.Card = c.urls[idx].Card - - ch <- t.imageResult - wg.Done() - <-limit - return - } - }(image, cluster) - <-limit - } -} - -func getInferResult(url string, file multipart.File, fileName string, clusterName string) (string, error) { - if clusterName == "鹏城云脑II-modelarts" { - r, err := getInferResultModelarts(url, file, fileName) - log.Printf("图形识别url: %s", url) - if err != nil { - return "", err - } - return r, nil - } - var res Res - req := GetRestyRequest(20) - _, err := req. - SetFileReader("file", fileName, file). - SetResult(&res). - Post(url) - if err != nil { - return "", err - } - return res.Result, nil -} - -func getInferResultModelarts(url string, file multipart.File, fileName string) (string, error) { - var res Res - body, err := utils.SendRequest("POST", url, file, fileName) - log.Printf("图形识别url: %s", url) - if err != nil { - return "", err - } - errjson := json.Unmarshal([]byte(body), &res) - if errjson != nil { - log.Fatalf("Error parsing JSON: %s", errjson) - } - log.Printf("推理结果: %s", res.Result) - return res.Result, nil -} - -func GetRestyRequest(timeoutSeconds int64) *resty.Request { - client := resty.New().SetTimeout(time.Duration(timeoutSeconds) * time.Second) - request := client.R() - return request -} - -type Res struct { - Result string `json:"result"` -} - -func contains(cs []struct { - urls []*collector.InferUrl - clusterId string - clusterName string - imageNum int32 -}, e string) bool { - for _, c := range cs { - if c.clusterId == e { - return true - } - } - return false + return resp, nil } diff --git a/api/internal/scheduler/database/aiStorage.go b/api/internal/scheduler/database/aiStorage.go index 8e4105e7..9aba3e52 100644 --- a/api/internal/scheduler/database/aiStorage.go +++ b/api/internal/scheduler/database/aiStorage.go @@ -94,6 +94,15 @@ func (s *AiStorage) GetAiTasksByAdapterId(adapterId string) ([]*models.TaskAi, e return resp, nil } +func (s *AiStorage) GetAiTaskListById(id int64) ([]*models.TaskAi, error) { + var aiTaskList []*models.TaskAi + tx := s.DbEngin.Raw("select * from task_ai where `task_id` = ? ", id).Scan(&aiTaskList) + if tx.Error != nil { + return nil, tx.Error + } + return aiTaskList, nil +} + func (s *AiStorage) SaveTask(name string, strategyCode int64, synergyStatus int64, aiType string) (int64, error) { startTime := time.Now() // 构建主任务结构体 @@ -165,6 +174,14 @@ func (s *AiStorage) SaveAiTask(taskId int64, opt option.Option, adapterName stri return nil } +func (s *AiStorage) SaveAiTaskImageSubTask(ta *models.TaskAiSub) error { + tx := s.DbEngin.Table("task_ai_sub").Create(ta) + if tx.Error != nil { + return tx.Error + } + return nil +} + func (s *AiStorage) SaveClusterTaskQueue(adapterId string, clusterId string, queueNum int64) error { aId, err := strconv.ParseInt(adapterId, 10, 64) if err != nil { diff --git a/api/internal/scheduler/service/aiService.go b/api/internal/scheduler/service/aiService.go index e2f27be7..1e1d9df0 100644 --- a/api/internal/scheduler/service/aiService.go +++ b/api/internal/scheduler/service/aiService.go @@ -1,12 +1,17 @@ package service import ( + "context" + "fmt" "github.com/zeromicro/go-zero/zrpc" hpcacclient "gitlink.org.cn/JointCloud/pcm-ac/hpcacclient" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/config" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/database" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/schedulers/option" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/collector" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/executor" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/inference" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/strategy" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/storeLink" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types" "gitlink.org.cn/JointCloud/pcm-modelarts/client/imagesservice" @@ -86,6 +91,15 @@ func InitAiClusterMap(conf *config.Config, clusters []types.ClusterInfo) (map[st return executorMap, collectorMap } +func (as *AiService) ImageInfer(opt *option.InferOption, id int64, adapterName string, clusters []*strategy.AssignedCluster, ts []*inference.ImageFile, ctx context.Context) { + + res, err := inference.Infer(opt, id, adapterName, clusters, ts, as.AiCollectorAdapterMap, as.Storage, ctx) + if err != nil { + return + } + fmt.Println(res) +} + //func (a *AiService) AddCluster() error { // //} diff --git a/api/internal/scheduler/service/inference/imageInfer.go b/api/internal/scheduler/service/inference/imageInfer.go new file mode 100644 index 00000000..732072b2 --- /dev/null +++ b/api/internal/scheduler/service/inference/imageInfer.go @@ -0,0 +1,385 @@ +package inference + +import ( + "context" + "encoding/json" + "errors" + "github.com/go-resty/resty/v2" + "github.com/zeromicro/go-zero/core/logx" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/database" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/schedulers/option" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/collector" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/strategy" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/storeLink" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types" + "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants" + "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models" + "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/utils" + "log" + "math/rand" + "mime/multipart" + "sort" + "strconv" + "sync" + "time" +) + +type ImageFile struct { + ImageResult *types.ImageResult + File multipart.File +} + +func Infer(opt *option.InferOption, id int64, adapterName string, clusters []*strategy.AssignedCluster, ts []*ImageFile, aiCollectorAdapterMap map[string]map[string]collector.AiCollector, storage *database.AiStorage, ctx context.Context) ([]*types.ImageResult, error) { + + for i := len(clusters) - 1; i >= 0; i-- { + if clusters[i].Replicas == 0 { + clusters = append(clusters[:i], clusters[i+1:]...) + } + } + + var wg sync.WaitGroup + var cluster_ch = make(chan struct { + urls []*collector.InferUrl + clusterId string + clusterName string + imageNum int32 + }, len(clusters)) + + var cs []struct { + urls []*collector.InferUrl + clusterId string + clusterName string + imageNum int32 + } + collectorMap := aiCollectorAdapterMap[opt.AdapterId] + + //save taskai + for _, c := range clusters { + clusterName, _ := storage.GetClusterNameById(c.ClusterId) + opt.Replica = c.Replicas + err := storage.SaveAiTask(id, opt, adapterName, c.ClusterId, clusterName, "", constants.Saved, "") + if err != nil { + return nil, err + } + } + + for _, cluster := range clusters { + wg.Add(1) + c := cluster + go func() { + imageUrls, err := collectorMap[c.ClusterId].GetInferUrl(ctx, opt) + for i, _ := range imageUrls { + imageUrls[i].Url = imageUrls[i].Url + storeLink.FORWARD_SLASH + "image" + } + if err != nil { + wg.Done() + return + } + clusterName, _ := storage.GetClusterNameById(c.ClusterId) + + s := struct { + urls []*collector.InferUrl + clusterId string + clusterName string + imageNum int32 + }{ + urls: imageUrls, + clusterId: c.ClusterId, + clusterName: clusterName, + imageNum: c.Replicas, + } + + cluster_ch <- s + wg.Done() + return + }() + } + wg.Wait() + close(cluster_ch) + + for s := range cluster_ch { + cs = append(cs, s) + } + + aiTaskList, err := storage.GetAiTaskListById(id) + if err != nil { + return nil, err + } + + //no cluster available + if len(cs) == 0 { + for _, t := range aiTaskList { + t.Status = constants.Failed + t.EndTime = time.Now().Format(time.RFC3339) + err := storage.UpdateAiTask(t) + if err != nil { + logx.Errorf(err.Error()) + } + } + storage.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败") + return nil, errors.New("image infer task failed") + } + + //change cluster status + if len(clusters) != len(cs) { + var acs []*strategy.AssignedCluster + var rcs []*strategy.AssignedCluster + for _, cluster := range clusters { + if contains(cs, cluster.ClusterId) { + var ac *strategy.AssignedCluster + ac = cluster + rcs = append(rcs, ac) + } else { + var ac *strategy.AssignedCluster + ac = cluster + acs = append(acs, ac) + } + } + + // update failed cluster status + for _, ac := range acs { + for _, t := range aiTaskList { + if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) { + t.Status = constants.Failed + t.EndTime = time.Now().Format(time.RFC3339) + err := storage.UpdateAiTask(t) + if err != nil { + logx.Errorf(err.Error()) + } + } + } + } + + // update running cluster status + for _, ac := range rcs { + for _, t := range aiTaskList { + if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) { + t.Status = constants.Running + err := storage.UpdateAiTask(t) + if err != nil { + logx.Errorf(err.Error()) + } + } + } + } + storage.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败") + } else { + for _, t := range aiTaskList { + t.Status = constants.Running + err := storage.UpdateAiTask(t) + if err != nil { + logx.Errorf(err.Error()) + } + } + storage.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "running", "任务运行中") + } + + var result_ch = make(chan *types.ImageResult, len(ts)) + var results []*types.ImageResult + limit := make(chan bool, 7) + + var imageNumIdx int32 = 0 + var imageNumIdxEnd int32 = 0 + for _, c := range cs { + new_images := make([]*ImageFile, len(ts)) + copy(new_images, ts) + + imageNumIdxEnd = imageNumIdxEnd + c.imageNum + new_images = new_images[imageNumIdx:imageNumIdxEnd] + imageNumIdx = imageNumIdx + c.imageNum + + wg.Add(len(new_images)) + go sendInferReq(new_images, c, &wg, result_ch, limit) + } + wg.Wait() + close(result_ch) + + for s := range result_ch { + results = append(results, s) + } + + sort.Slice(results, func(p, q int) bool { + return results[p].ClusterName < results[q].ClusterName + }) + + //save ai sub tasks + for _, r := range results { + for _, task := range aiTaskList { + if r.ClusterId == strconv.Itoa(int(task.ClusterId)) { + taskAiSub := models.TaskAiSub{ + TaskId: id, + TaskName: task.Name, + TaskAiId: task.TaskId, + TaskAiName: task.Name, + ImageName: r.ImageName, + Result: r.ImageResult, + Card: r.Card, + ClusterId: task.ClusterId, + ClusterName: r.ClusterName, + } + err := storage.SaveAiTaskImageSubTask(&taskAiSub) + if err != nil { + panic(err) + } + } + } + } + + // update succeeded cluster status + var successStatusCount int + for _, c := range cs { + for _, t := range aiTaskList { + if c.clusterId == strconv.Itoa(int(t.ClusterId)) { + t.Status = constants.Completed + t.EndTime = time.Now().Format(time.RFC3339) + err := storage.UpdateAiTask(t) + if err != nil { + logx.Errorf(err.Error()) + } + successStatusCount++ + } else { + continue + } + } + } + + if len(cs) == successStatusCount { + storage.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "completed", "任务完成") + } else { + storage.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败") + } + + return results, nil +} + +func sendInferReq(images []*ImageFile, cluster struct { + urls []*collector.InferUrl + clusterId string + clusterName string + imageNum int32 +}, wg *sync.WaitGroup, ch chan<- *types.ImageResult, limit chan bool) { + for _, image := range images { + limit <- true + go func(t *ImageFile, c struct { + urls []*collector.InferUrl + clusterId string + clusterName string + imageNum int32 + }) { + if len(c.urls) == 1 { + r, err := getInferResult(c.urls[0].Url, t.File, t.ImageResult.ImageName, c.clusterName) + if err != nil { + t.ImageResult.ImageResult = err.Error() + t.ImageResult.ClusterId = c.clusterId + t.ImageResult.ClusterName = c.clusterName + t.ImageResult.Card = c.urls[0].Card + ch <- t.ImageResult + wg.Done() + <-limit + return + } + t.ImageResult.ImageResult = r + t.ImageResult.ClusterId = c.clusterId + t.ImageResult.ClusterName = c.clusterName + t.ImageResult.Card = c.urls[0].Card + + ch <- t.ImageResult + wg.Done() + <-limit + return + } else { + idx := rand.Intn(len(c.urls)) + r, err := getInferResult(c.urls[idx].Url, t.File, t.ImageResult.ImageName, c.clusterName) + if err != nil { + t.ImageResult.ImageResult = err.Error() + t.ImageResult.ClusterId = c.clusterId + t.ImageResult.ClusterName = c.clusterName + t.ImageResult.Card = c.urls[idx].Card + ch <- t.ImageResult + wg.Done() + <-limit + return + } + t.ImageResult.ImageResult = r + t.ImageResult.ClusterId = c.clusterId + t.ImageResult.ClusterName = c.clusterName + t.ImageResult.Card = c.urls[idx].Card + + ch <- t.ImageResult + wg.Done() + <-limit + return + } + }(image, cluster) + <-limit + } +} + +func getInferResult(url string, file multipart.File, fileName string, clusterName string) (string, error) { + if clusterName == "鹏城云脑II-modelarts" { + r, err := getInferResultModelarts(url, file, fileName) + if err != nil { + return "", err + } + return r, nil + } + var res Res + req := GetRestyRequest(20) + _, err := req. + SetFileReader("file", fileName, file). + SetResult(&res). + Post(url) + if err != nil { + return "", err + } + return res.Result, nil +} + +func getInferResultModelarts(url string, file multipart.File, fileName string) (string, error) { + var res Res + /* req := GetRestyRequest(20) + _, err := req. + SetFileReader("file", fileName, file). + SetHeaders(map[string]string{ + "ak": "UNEHPHO4Z7YSNPKRXFE4", + "sk": "JWXCE9qcYbc7RjpSRIWt4WgG3ZKF6Q4lPzkJReX9", + }). + SetResult(&res). + Post(url) + if err != nil { + return "", err + }*/ + body, err := utils.SendRequest("POST", url, file, fileName) + if err != nil { + return "", err + } + errjson := json.Unmarshal([]byte(body), &res) + if errjson != nil { + log.Fatalf("Error parsing JSON: %s", errjson) + } + return res.Result, nil +} + +func GetRestyRequest(timeoutSeconds int64) *resty.Request { + client := resty.New().SetTimeout(time.Duration(timeoutSeconds) * time.Second) + request := client.R() + return request +} + +type Res struct { + Result string `json:"result"` +} + +func contains(cs []struct { + urls []*collector.InferUrl + clusterId string + clusterName string + imageNum int32 +}, e string) bool { + for _, c := range cs { + if c.clusterId == e { + return true + } + } + return false +}