From a60e97134d6659006849d2d70ec8bd06d1f05d55 Mon Sep 17 00:00:00 2001 From: tzwang Date: Tue, 27 Aug 2024 17:47:20 +0800 Subject: [PATCH] updated inference api logics Former-commit-id: 7c4b471eb2de5b6f6bbe41b4d5b62c940ac4ec4b --- .../getrunninginstancebyidhandler.go | 10 ++- .../inference/getrunninginstancebyidlogic.go | 10 ++- .../logic/inference/imageinferencelogic.go | 70 ++++++++++++------- internal/scheduler/database/aiStorage.go | 4 +- .../imageInference/imageInference.go | 29 +++++++- .../scheduler/service/inference/inference.go | 1 + 6 files changed, 87 insertions(+), 37 deletions(-) diff --git a/internal/handler/inference/getrunninginstancebyidhandler.go b/internal/handler/inference/getrunninginstancebyidhandler.go index 5861991d..ec8e3605 100644 --- a/internal/handler/inference/getrunninginstancebyidhandler.go +++ b/internal/handler/inference/getrunninginstancebyidhandler.go @@ -1,6 +1,7 @@ package inference import ( + "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/repository/result" "net/http" "github.com/zeromicro/go-zero/rest/httpx" @@ -13,16 +14,13 @@ func GetRunningInstanceByIdHandler(svcCtx *svc.ServiceContext) http.HandlerFunc return func(w http.ResponseWriter, r *http.Request) { var req types.GetRunningInstanceReq if err := httpx.Parse(r, &req); err != nil { - httpx.ErrorCtx(r.Context(), w, err) + result.ParamErrorResult(r, w, err) return } l := inference.NewGetRunningInstanceByIdLogic(r.Context(), svcCtx) resp, err := l.GetRunningInstanceById(&req) - if err != nil { - httpx.ErrorCtx(r.Context(), w, err) - } else { - httpx.OkJsonCtx(r.Context(), w, resp) - } + result.HttpResult(r, w, resp, err) + } } diff --git a/internal/logic/inference/getrunninginstancebyidlogic.go b/internal/logic/inference/getrunninginstancebyidlogic.go index ab4d75c0..80c25976 100644 --- a/internal/logic/inference/getrunninginstancebyidlogic.go +++ b/internal/logic/inference/getrunninginstancebyidlogic.go @@ -2,6 +2,7 @@ package inference import ( "context" + "strconv" "gitlink.org.cn/JointCloud/pcm-coordinator/internal/svc" "gitlink.org.cn/JointCloud/pcm-coordinator/internal/types" @@ -24,7 +25,14 @@ func NewGetRunningInstanceByIdLogic(ctx context.Context, svcCtx *svc.ServiceCont } func (l *GetRunningInstanceByIdLogic) GetRunningInstanceById(req *types.GetRunningInstanceReq) (resp *types.GetRunningInstanceResp, err error) { - // todo: add your logic here and delete this line + resp = &types.GetRunningInstanceResp{} + id, err := strconv.ParseInt(req.Id, 10, 64) + if err != nil { + return nil, err + } + list, err := l.svcCtx.Scheduler.AiStorages.GetRunningDeployInstanceById(id, req.AdapterId) + + resp.List = list return } diff --git a/internal/logic/inference/imageinferencelogic.go b/internal/logic/inference/imageinferencelogic.go index 9cc749e7..88f93ded 100644 --- a/internal/logic/inference/imageinferencelogic.go +++ b/internal/logic/inference/imageinferencelogic.go @@ -34,6 +34,10 @@ func NewImageInferenceLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Im func (l *ImageInferenceLogic) ImageInfer(r *http.Request, req *types.ImageInferenceReq) (resp *types.ImageInferenceResp, err error) { resp = &types.ImageInferenceResp{} + if len(req.Instances) == 0 { + return nil, errors.New("instances are empty") + } + opt := &option.InferOption{ TaskName: req.TaskName, TaskDesc: req.TaskDesc, @@ -72,42 +76,54 @@ func (l *ImageInferenceLogic) ImageInfer(r *http.Request, req *types.ImageInfere ts = append(ts, &t) } - _, ok := l.svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId] - if !ok { - return nil, errors.New("AdapterId does not exist") - } + //_, ok := l.svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId] + //if !ok { + // return nil, errors.New("AdapterId does not exist") + //} + // - var strat strategy.Strategy - switch opt.Strategy { - case strategy.STATIC_WEIGHT: - strat = strategy.NewStaticWeightStrategy(opt.StaticWeightMap, int32(len(ts))) + var cs []*strategy.AssignedCluster + var adapterName string + if opt.Strategy != "" { + var strat strategy.Strategy + switch opt.Strategy { + case strategy.STATIC_WEIGHT: + strat = strategy.NewStaticWeightStrategy(opt.StaticWeightMap, int32(len(ts))) + if err != nil { + return nil, err + } + default: + return nil, errors.New("no strategy has been chosen") + } + clusters, err := strat.Schedule() if err != nil { return nil, err } - default: - return nil, errors.New("no strategy has been chosen") - } - clusters, err := strat.Schedule() - if err != nil { - return nil, err - } - if clusters == nil || len(clusters) == 0 { - return nil, errors.New("clusters is nil") - } - - for i := len(clusters) - 1; i >= 0; i-- { - if clusters[i].Replicas == 0 { - clusters = append(clusters[:i], clusters[i+1:]...) + if clusters == nil || len(clusters) == 0 { + return nil, errors.New("clusters is nil") } + + for i := len(clusters) - 1; i >= 0; i-- { + if clusters[i].Replicas == 0 { + clusters = append(clusters[:i], clusters[i+1:]...) + } + } + + name, err := l.svcCtx.Scheduler.AiStorages.GetAdapterNameById(opt.AdapterId) + if err != nil { + return nil, err + } + adapterName = name } - adapterName, err := l.svcCtx.Scheduler.AiStorages.GetAdapterNameById(opt.AdapterId) - if err != nil { - return nil, err - } + //else { + // for i, instance := range req.Instances { + // + // } + //} - imageInfer, err := imageInference.New(imageInference.NewImageClassification(), ts, clusters, opt, l.svcCtx.Scheduler.AiStorages, l.svcCtx.Scheduler.AiService.InferenceAdapterMap, adapterName) + imageInfer, err := imageInference.New(imageInference.NewImageClassification(), ts, cs, req.Instances, opt, l.svcCtx.Scheduler.AiStorages, l.svcCtx.Scheduler.AiService.InferenceAdapterMap, adapterName) if err != nil { return nil, err } diff --git a/internal/scheduler/database/aiStorage.go b/internal/scheduler/database/aiStorage.go index 5735c3ac..6f9c4285 100644 --- a/internal/scheduler/database/aiStorage.go +++ b/internal/scheduler/database/aiStorage.go @@ -584,9 +584,9 @@ func (s *AiStorage) SaveInferDeployTask(taskName string, modelName string, model return taskModel.Id, nil } -func (s *AiStorage) GetRunningDeployInstanceByModelNameAndAdapterId(modelType string, modelName string, adapterId string) ([]*models.AiInferDeployInstance, error) { +func (s *AiStorage) GetRunningDeployInstanceById(id int64, adapterId string) ([]*models.AiInferDeployInstance, error) { var list []*models.AiInferDeployInstance - tx := s.DbEngin.Raw("select * from ai_infer_deploy_instance where `model_type` = ? and `model_name` = ? and `adapter_id` = ? and `status` = 'Running'", modelType, modelName, adapterId).Scan(&list) + tx := s.DbEngin.Raw("select * from ai_infer_deploy_instance where `deploy_instance_task_id` = ? and `adapter_id` = ? and `status` = 'Running'", id, adapterId).Scan(&list) if tx.Error != nil { logx.Errorf(tx.Error.Error()) return nil, tx.Error diff --git a/internal/scheduler/service/inference/imageInference/imageInference.go b/internal/scheduler/service/inference/imageInference/imageInference.go index ea569056..8c45c5a1 100644 --- a/internal/scheduler/service/inference/imageInference/imageInference.go +++ b/internal/scheduler/service/inference/imageInference/imageInference.go @@ -46,6 +46,7 @@ type ImageInference struct { inference IImageInference files []*ImageFile clusters []*strategy.AssignedCluster + instances []types.DeployInstance opt *option.InferOption storage *database.AiStorage inferAdapter map[string]map[string]inference.ICluster @@ -57,6 +58,7 @@ func New( inference IImageInference, files []*ImageFile, clusters []*strategy.AssignedCluster, + instances []types.DeployInstance, opt *option.InferOption, storage *database.AiStorage, inferAdapter map[string]map[string]inference.ICluster, @@ -66,6 +68,7 @@ func New( inference: inference, files: files, clusters: clusters, + instances: instances, opt: opt, storage: storage, inferAdapter: inferAdapter, @@ -145,7 +148,7 @@ func (i *ImageInference) saveAiTask(id int64) error { return nil } -func (i *ImageInference) filterClusters() ([]*FilteredCluster, error) { +func (i *ImageInference) filterClustersTemp() ([]*FilteredCluster, error) { var wg sync.WaitGroup var ch = make(chan *FilteredCluster, len(i.clusters)) var cs []*FilteredCluster @@ -190,6 +193,30 @@ func (i *ImageInference) filterClusters() ([]*FilteredCluster, error) { return cs, nil } +func (i *ImageInference) filterClusters() ([]*FilteredCluster, error) { + var cs []*FilteredCluster + for _, cluster := range i.clusters { + var inferurls []*inference.InferUrl + for _, instance := range i.instances { + if cluster.ClusterId == instance.ClusterId { + r := http.Request{} + deployInstance, err := i.inferAdapter[instance.AdapterId][instance.ClusterId].GetInferDeployInstance(r.Context(), instance.InstanceId) + if err != nil { + return nil, err + } + var url inference.InferUrl + url.Url = deployInstance.InferUrl + url.Card = deployInstance.InferCard + inferurls = append(inferurls, &url) + } + } + var f FilteredCluster + f.urls = inferurls + cs = append(cs, &f) + } + return cs, nil +} + func (i *ImageInference) inferImages(cs []*FilteredCluster) ([]*types.ImageResult, error) { var wg sync.WaitGroup var ch = make(chan *types.ImageResult, len(i.files)) diff --git a/internal/scheduler/service/inference/inference.go b/internal/scheduler/service/inference/inference.go index 1e8442bc..4695ec07 100644 --- a/internal/scheduler/service/inference/inference.go +++ b/internal/scheduler/service/inference/inference.go @@ -46,6 +46,7 @@ type DeployInstance struct { ModelName string ModelType string InferCard string + InferUrl string ClusterName string ClusterType string Status string