From 739948d184cb499a4548e7000660068f88827ed1 Mon Sep 17 00:00:00 2001 From: tzwang Date: Wed, 28 Aug 2024 17:24:39 +0800 Subject: [PATCH] updated imageinference logics Former-commit-id: 3dde5aa691fa23f3ebae3772620640d96ab70570 --- .../logic/inference/imageinferencelogic.go | 56 +++++++++++++------ .../imageInference/imageInference.go | 11 ++-- 2 files changed, 47 insertions(+), 20 deletions(-) diff --git a/internal/logic/inference/imageinferencelogic.go b/internal/logic/inference/imageinferencelogic.go index 94b1ddad..2b26e8b9 100644 --- a/internal/logic/inference/imageinferencelogic.go +++ b/internal/logic/inference/imageinferencelogic.go @@ -10,7 +10,9 @@ import ( "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/strategy" "gitlink.org.cn/JointCloud/pcm-coordinator/internal/svc" "gitlink.org.cn/JointCloud/pcm-coordinator/internal/types" + "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models" "net/http" + "strconv" ) type ImageInferenceLogic struct { @@ -34,21 +36,24 @@ func NewImageInferenceLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Im func (l *ImageInferenceLogic) ImageInfer(r *http.Request, req *types.ImageInferenceReq) (resp *types.ImageInferenceResp, err error) { resp = &types.ImageInferenceResp{} - if len(req.Instances) == 0 { + if len(req.InstanceIds) == 0 { return nil, errors.New("instances are empty") } - opt := &option.InferOption{ - TaskName: req.TaskName, - TaskDesc: req.TaskDesc, - //AdapterId: req.AdapterId, - //AiClusterIds: req.AiClusterIds, - //ModelName: req.ModelName, - ModelType: req.ModelType, - Strategy: req.Strategy, - StaticWeightMap: req.StaticWeightMap, + var instanceList []*models.AiInferDeployInstance + for _, id := range req.InstanceIds { + instance, err := l.svcCtx.Scheduler.AiStorages.GetInferDeployInstanceById(id) + if err != nil { + return nil, err + } + instanceList = append(instanceList, instance) } + if len(instanceList) == 0 { + return nil, errors.New("instances are empty") + } + + // process uploaded images var ts []*imageInference.ImageFile uploadedFiles := r.MultipartForm.File @@ -76,17 +81,35 @@ func (l *ImageInferenceLogic) ImageInfer(r *http.Request, req *types.ImageInfere ts = append(ts, &t) } - //_, ok := l.svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId] - //if !ok { - // return nil, errors.New("AdapterId does not exist") - //} - // + //single adapter logic + if len(req.StaticWeightMap) != 1 { + return nil, errors.New("staticWeightMap != 1") + } + + adapterId := strconv.FormatInt(instanceList[0].AdapterId, 10) + staticWeightMap, ok := req.StaticWeightMap[adapterId] + if !ok { + return nil, errors.New("set staticWeightMap failed") + } + + // create InferOption + opt := &option.InferOption{ + TaskName: req.TaskName, + TaskDesc: req.TaskDesc, + AdapterId: adapterId, + //AiClusterIds: req.AiClusterIds, + //ModelName: req.ModelName, + ModelType: req.ModelType, + Strategy: req.Strategy, + StaticWeightMap: staticWeightMap, + } adapterName, err := l.svcCtx.Scheduler.AiStorages.GetAdapterNameById(opt.AdapterId) if err != nil { return nil, err } + // set strategy if opt.Strategy != "" { return nil, errors.New("strategy is empty") } @@ -116,7 +139,8 @@ func (l *ImageInferenceLogic) ImageInfer(r *http.Request, req *types.ImageInfere } } - imageInfer, err := imageInference.New(imageInference.NewImageClassification(), ts, clusters, req.Instances, opt, l.svcCtx.Scheduler.AiStorages, l.svcCtx.Scheduler.AiService.InferenceAdapterMap, adapterName) + // create inference struct + imageInfer, err := imageInference.New(imageInference.NewImageClassification(), ts, clusters, instanceList, opt, l.svcCtx.Scheduler.AiStorages, l.svcCtx.Scheduler.AiService.InferenceAdapterMap, adapterName) if err != nil { return nil, err } diff --git a/internal/scheduler/service/inference/imageInference/imageInference.go b/internal/scheduler/service/inference/imageInference/imageInference.go index a11b80f9..101ba45f 100644 --- a/internal/scheduler/service/inference/imageInference/imageInference.go +++ b/internal/scheduler/service/inference/imageInference/imageInference.go @@ -46,7 +46,7 @@ type ImageInference struct { inference IImageInference files []*ImageFile clusters []*strategy.AssignedCluster - instances []types.DeployInstance + instances []*models.AiInferDeployInstance opt *option.InferOption storage *database.AiStorage inferAdapter map[string]map[string]inference.ICluster @@ -58,7 +58,7 @@ func New( inference IImageInference, files []*ImageFile, clusters []*strategy.AssignedCluster, - instances []types.DeployInstance, + instances []*models.AiInferDeployInstance, opt *option.InferOption, storage *database.AiStorage, inferAdapter map[string]map[string]inference.ICluster, @@ -199,9 +199,12 @@ func (i *ImageInference) filterClusters() ([]*FilteredCluster, error) { var inferurls []*inference.InferUrl var clustertype string for _, instance := range i.instances { - if cluster.ClusterId == instance.ClusterId { + clusterId := strconv.FormatInt(instance.ClusterId, 10) + adapterId := strconv.FormatInt(instance.AdapterId, 10) + + if cluster.ClusterId == clusterId { r := http.Request{} - deployInstance, err := i.inferAdapter[instance.AdapterId][instance.ClusterId].GetInferDeployInstance(r.Context(), instance.InstanceId) + deployInstance, err := i.inferAdapter[adapterId][clusterId].GetInferDeployInstance(r.Context(), instance.InstanceId) if err != nil { continue }