From 3c3d45fe6200239b9da46f6e66e745783a1d9cee Mon Sep 17 00:00:00 2001 From: tzwang Date: Wed, 28 Aug 2024 15:22:41 +0800 Subject: [PATCH] updated imageinference logics Former-commit-id: 1bb4a47ac943ee12a0d92ce37820af58b155641e --- .../logic/inference/imageinferencelogic.go | 71 +++++++++---------- .../imageInference/imageInference.go | 16 ++++- .../inference/textInference/textToImage.go | 2 +- .../inference/textInference/textToText.go | 2 + internal/storeLink/octopus.go | 5 ++ 5 files changed, 54 insertions(+), 42 deletions(-) diff --git a/internal/logic/inference/imageinferencelogic.go b/internal/logic/inference/imageinferencelogic.go index 88f93ded..94b1ddad 100644 --- a/internal/logic/inference/imageinferencelogic.go +++ b/internal/logic/inference/imageinferencelogic.go @@ -82,48 +82,41 @@ func (l *ImageInferenceLogic) ImageInfer(r *http.Request, req *types.ImageInfere //} // - var cs []*strategy.AssignedCluster - var adapterName string - if opt.Strategy != "" { - var strat strategy.Strategy - switch opt.Strategy { - case strategy.STATIC_WEIGHT: - strat = strategy.NewStaticWeightStrategy(opt.StaticWeightMap, int32(len(ts))) - if err != nil { - return nil, err - } - default: - return nil, errors.New("no strategy has been chosen") - } - clusters, err := strat.Schedule() - if err != nil { - return nil, err - } - - if clusters == nil || len(clusters) == 0 { - return nil, errors.New("clusters is nil") - } - - for i := len(clusters) - 1; i >= 0; i-- { - if clusters[i].Replicas == 0 { - clusters = append(clusters[:i], clusters[i+1:]...) - } - } - - name, err := l.svcCtx.Scheduler.AiStorages.GetAdapterNameById(opt.AdapterId) - if err != nil { - return nil, err - } - adapterName = name + adapterName, err := l.svcCtx.Scheduler.AiStorages.GetAdapterNameById(opt.AdapterId) + if err != nil { + return nil, err } - //else { - // for i, instance := range req.Instances { - // - // } - //} + if opt.Strategy != "" { + return nil, errors.New("strategy is empty") + } - imageInfer, err := imageInference.New(imageInference.NewImageClassification(), ts, cs, req.Instances, opt, l.svcCtx.Scheduler.AiStorages, l.svcCtx.Scheduler.AiService.InferenceAdapterMap, adapterName) + var strat strategy.Strategy + switch opt.Strategy { + case strategy.STATIC_WEIGHT: + strat = strategy.NewStaticWeightStrategy(opt.StaticWeightMap, int32(len(ts))) + if err != nil { + return nil, err + } + default: + return nil, errors.New("no strategy has been chosen") + } + clusters, err := strat.Schedule() + if err != nil { + return nil, err + } + + if clusters == nil || len(clusters) == 0 { + return nil, errors.New("clusters is nil") + } + + for i := len(clusters) - 1; i >= 0; i-- { + if clusters[i].Replicas == 0 { + clusters = append(clusters[:i], clusters[i+1:]...) + } + } + + imageInfer, err := imageInference.New(imageInference.NewImageClassification(), ts, clusters, req.Instances, opt, l.svcCtx.Scheduler.AiStorages, l.svcCtx.Scheduler.AiService.InferenceAdapterMap, adapterName) if err != nil { return nil, err } diff --git a/internal/scheduler/service/inference/imageInference/imageInference.go b/internal/scheduler/service/inference/imageInference/imageInference.go index 8c45c5a1..a11b80f9 100644 --- a/internal/scheduler/service/inference/imageInference/imageInference.go +++ b/internal/scheduler/service/inference/imageInference/imageInference.go @@ -131,7 +131,7 @@ func (i *ImageInference) saveTask() (int64, error) { return 0, err } - i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "create", "任务创建中") + i.storage.AddNoticeInfo("", "", "", "", i.opt.TaskName, "create", "任务创建中") return id, nil } @@ -197,21 +197,33 @@ func (i *ImageInference) filterClusters() ([]*FilteredCluster, error) { var cs []*FilteredCluster for _, cluster := range i.clusters { var inferurls []*inference.InferUrl + var clustertype string for _, instance := range i.instances { if cluster.ClusterId == instance.ClusterId { r := http.Request{} deployInstance, err := i.inferAdapter[instance.AdapterId][instance.ClusterId].GetInferDeployInstance(r.Context(), instance.InstanceId) if err != nil { - return nil, err + continue } var url inference.InferUrl url.Url = deployInstance.InferUrl url.Card = deployInstance.InferCard inferurls = append(inferurls, &url) + + clustertype = deployInstance.ClusterType } } + if len(inferurls) == 0 { + continue + } + + i.inference.AppendRoute(inferurls) + var f FilteredCluster f.urls = inferurls + f.clusterName = cluster.ClusterName + f.clusterType = clustertype + f.imageNum = cluster.Replicas cs = append(cs, &f) } return cs, nil diff --git a/internal/scheduler/service/inference/textInference/textToImage.go b/internal/scheduler/service/inference/textInference/textToImage.go index 7fe2496c..bc803444 100644 --- a/internal/scheduler/service/inference/textInference/textToImage.go +++ b/internal/scheduler/service/inference/textInference/textToImage.go @@ -9,7 +9,7 @@ import ( ) const ( - TEXTTOIMAGE = "text-to-image" + TEXTTOIMAGE = "generate_image" TEXTTOIMAGE_AiTYPE = "14" ) diff --git a/internal/scheduler/service/inference/textInference/textToText.go b/internal/scheduler/service/inference/textInference/textToText.go index b795e766..3851a35e 100644 --- a/internal/scheduler/service/inference/textInference/textToText.go +++ b/internal/scheduler/service/inference/textInference/textToText.go @@ -79,9 +79,11 @@ func filterClusters(opt *option.InferOption, storage *database.AiStorage, inferA wg.Done() return } + for i, _ := range clusterInferUrl.InferUrls { clusterInferUrl.InferUrls[i].Url = clusterInferUrl.InferUrls[i].Url + inference.FORWARD_SLASH + CHAT } + clusterName, _ := storage.GetClusterNameById(cId) var f FilteredCluster diff --git a/internal/storeLink/octopus.go b/internal/storeLink/octopus.go index f2c4a3aa..0b2fc3fd 100644 --- a/internal/storeLink/octopus.go +++ b/internal/storeLink/octopus.go @@ -1154,11 +1154,16 @@ func (o *OctopusLink) GetInferDeployInstance(ctx context.Context, id string) (*i if resp.Payload == nil { return nil, errors.New("instance does not exist") } + + url := strings.Replace(resp.Payload.Notebook.Tasks[0].Url, FORWARD_SLASH, "", -1) + inferUrl := DOMAIN + url + ins.InstanceName = resp.Payload.Notebook.Name ins.InstanceId = resp.Payload.Notebook.Id ins.ClusterName = o.platform ins.Status = resp.Payload.Notebook.Status ins.ClusterType = TYPE_OCTOPUS + ins.InferUrl = inferUrl return ins, nil }