From 734280ea00f794209d87c96de03f7b0b4166dfa9 Mon Sep 17 00:00:00 2001 From: tzwang Date: Thu, 20 Jun 2024 20:39:10 +0800 Subject: [PATCH] added imageinfer api Former-commit-id: 96ae7d54f30190581c662c5f8519efc5d92b555b --- .../inference/imageinferencehandler.go | 14 +- .../logic/inference/imageinferencelogic.go | 262 +++++++++++++++++- .../schedulers/option/inferOption.go | 18 ++ .../scheduler/service/collector/collector.go | 11 +- api/internal/storeLink/modelarts.go | 11 + api/internal/storeLink/octopus.go | 20 ++ api/internal/storeLink/shuguangai.go | 11 + 7 files changed, 331 insertions(+), 16 deletions(-) create mode 100644 api/internal/scheduler/schedulers/option/inferOption.go diff --git a/api/internal/handler/inference/imageinferencehandler.go b/api/internal/handler/inference/imageinferencehandler.go index d617c790..04ae9f78 100644 --- a/api/internal/handler/inference/imageinferencehandler.go +++ b/api/internal/handler/inference/imageinferencehandler.go @@ -1,28 +1,24 @@ package inference import ( - "net/http" - "github.com/zeromicro/go-zero/rest/httpx" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/logic/inference" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types" + "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/repository/result" + "net/http" ) func ImageInferenceHandler(svcCtx *svc.ServiceContext) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var req types.ImageInferenceReq if err := httpx.Parse(r, &req); err != nil { - httpx.ErrorCtx(r.Context(), w, err) + result.ParamErrorResult(r, w, err) return } l := inference.NewImageInferenceLogic(r.Context(), svcCtx) - resp, err := l.ImageInference(&req) - if err != nil { - httpx.ErrorCtx(r.Context(), w, err) - } else { - httpx.OkJsonCtx(r.Context(), w, resp) - } + resp, err := l.ImageInfer(r, &req) + result.HttpResult(r, w, resp, err) } } diff --git a/api/internal/logic/inference/imageinferencelogic.go b/api/internal/logic/inference/imageinferencelogic.go index 0a1f9d3e..5230c0be 100644 --- a/api/internal/logic/inference/imageinferencelogic.go +++ b/api/internal/logic/inference/imageinferencelogic.go @@ -2,11 +2,18 @@ package inference import ( "context" - + "errors" + "github.com/go-resty/resty/v2" + "github.com/zeromicro/go-zero/core/logx" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/schedulers/option" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/collector" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/strategy" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types" - - "github.com/zeromicro/go-zero/core/logx" + "math/rand" + "mime/multipart" + "net/http" + "sync" ) type ImageInferenceLogic struct { @@ -24,7 +31,250 @@ func NewImageInferenceLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Im } func (l *ImageInferenceLogic) ImageInference(req *types.ImageInferenceReq) (resp *types.ImageInferenceResp, err error) { - // todo: add your logic here and delete this line - - return + return nil, nil +} + +func (l *ImageInferenceLogic) ImageInfer(r *http.Request, req *types.ImageInferenceReq) (resp *types.ImageInferenceResp, err error) { + resp = &types.ImageInferenceResp{} + opt := &option.InferOption{ + TaskName: req.TaskName, + TaskDesc: req.TaskDesc, + AdapterId: req.AdapterId, + AiClusterIds: req.AiClusterIds, + ModelName: req.ModelName, + ModelType: req.ModelType, + Strategy: req.Strategy, + StaticWeightMap: req.StaticWeightMap, + } + + var ts []struct { + imageResult *types.ImageResult + file multipart.File + } + + uploadedFiles := r.MultipartForm.File + + if len(uploadedFiles) == 0 { + return nil, errors.New("Images does not exist") + } + + if len(uploadedFiles["images"]) == 0 { + return nil, errors.New("Images does not exist") + } + + for _, header := range uploadedFiles["images"] { + file, err := header.Open() + if err != nil { + return nil, err + } + defer file.Close() + var ir types.ImageResult + ir.ImageName = header.Filename + t := struct { + imageResult *types.ImageResult + file multipart.File + }{ + imageResult: &ir, + file: file, + } + ts = append(ts, t) + } + + _, ok := l.svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId] + if !ok { + return nil, errors.New("AdapterId does not exist") + } + + var strat strategy.Strategy + switch opt.Strategy { + case strategy.STATIC_WEIGHT: + //todo resources should match cluster StaticWeightMap + strat = strategy.NewStaticWeightStrategy(opt.StaticWeightMap, int32(len(ts))) + if err != nil { + return nil, err + } + default: + return nil, errors.New("no strategy has been chosen") + } + clusters, err := strat.Schedule() + if err != nil { + return nil, err + } + + results, err := infer(opt, clusters, ts, l.svcCtx, l.ctx) + if err != nil { + return nil, err + } + resp.InferResults = results + + return resp, nil +} + +func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []struct { + imageResult *types.ImageResult + file multipart.File +}, svcCtx *svc.ServiceContext, ctx context.Context) ([]*types.ImageResult, error) { + + if clusters == nil || len(clusters) == 0 { + return nil, errors.New("clusters is nil") + } + + for i := len(clusters) - 1; i >= 0; i-- { + if clusters[i].Replicas == 0 { + clusters = append(clusters[:i], clusters[i+1:]...) + } + } + + var wg sync.WaitGroup + var cluster_ch = make(chan struct { + urls []*collector.ImageInferUrl + clusterName string + imageNum int32 + }, len(clusters)) + + var cs []struct { + urls []*collector.ImageInferUrl + clusterName string + imageNum int32 + } + collectorMap := svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId] + + for _, cluster := range clusters { + wg.Add(1) + c := cluster + go func() { + imageUrls, err := collectorMap[c.ClusterId].GetImageInferUrl(ctx, opt) + if err != nil { + return + } + clusterName, _ := svcCtx.Scheduler.AiStorages.GetClusterNameById(c.ClusterId) + + s := struct { + urls []*collector.ImageInferUrl + clusterName string + imageNum int32 + }{ + urls: imageUrls, + clusterName: clusterName, + imageNum: c.Replicas, + } + + cluster_ch <- s + wg.Done() + return + }() + } + wg.Wait() + close(cluster_ch) + + for s := range cluster_ch { + cs = append(cs, s) + } + + var result_ch = make(chan *types.ImageResult, len(ts)) + var results []*types.ImageResult + + wg.Add(len(ts)) + + var imageNumIdx int32 = 0 + var imageNumIdxEnd int32 = 0 + for _, c := range cs { + new_images := make([]struct { + imageResult *types.ImageResult + file multipart.File + }, len(ts)) + copy(new_images, ts) + + imageNumIdxEnd = imageNumIdxEnd + c.imageNum + new_images = new_images[imageNumIdx:imageNumIdxEnd] + imageNumIdx = imageNumIdx + c.imageNum + + go sendInferReq(new_images, c, &wg, result_ch) + } + wg.Wait() + + close(result_ch) + + for s := range result_ch { + results = append(results, s) + } + + return results, nil +} + +func sendInferReq(images []struct { + imageResult *types.ImageResult + file multipart.File +}, cluster struct { + urls []*collector.ImageInferUrl + clusterName string + imageNum int32 +}, wg *sync.WaitGroup, ch chan<- *types.ImageResult) { + for _, image := range images { + go func(t struct { + imageResult *types.ImageResult + file multipart.File + }, c struct { + urls []*collector.ImageInferUrl + clusterName string + imageNum int32 + }) { + if len(c.urls) == 1 { + r, err := getInferResult(c.urls[0].Url, t.file, t.imageResult.ImageName) + if err != nil { + t.imageResult.ImageResult = err.Error() + ch <- t.imageResult + wg.Done() + return + } + t.imageResult.ImageResult = r + t.imageResult.ClusterName = c.clusterName + t.imageResult.Card = c.urls[0].Card + + ch <- t.imageResult + wg.Done() + return + } else { + idx := rand.Intn(len(c.urls)) + r, err := getInferResult(c.urls[idx].Url, t.file, t.imageResult.ImageName) + if err != nil { + t.imageResult.ImageResult = err.Error() + ch <- t.imageResult + wg.Done() + return + } + t.imageResult.ImageResult = r + t.imageResult.ClusterName = c.clusterName + t.imageResult.Card = c.urls[idx].Card + + ch <- t.imageResult + wg.Done() + return + } + }(image, cluster) + } +} + +func getInferResult(url string, file multipart.File, fileName string) (string, error) { + var res Res + req := GetACHttpRequest() + _, err := req. + SetFileReader("file", fileName, file). + SetResult(&res). + Post(url) + + if err != nil { + return "", err + } + return res.Result, nil +} + +func GetACHttpRequest() *resty.Request { + client := resty.New() + request := client.R() + return request +} + +type Res struct { + Result string `json:"result"` } diff --git a/api/internal/scheduler/schedulers/option/inferOption.go b/api/internal/scheduler/schedulers/option/inferOption.go new file mode 100644 index 00000000..b576eb0f --- /dev/null +++ b/api/internal/scheduler/schedulers/option/inferOption.go @@ -0,0 +1,18 @@ +package option + +type InferOption struct { + TaskName string `json:"taskName"` + TaskDesc string `json:"taskDesc"` + ModelName string `json:"modelName"` + ModelType string `json:"modelType"` + AdapterId string `json:"adapterId"` + AiClusterIds []string `json:"aiClusterIds"` + ResourceType string `json:"resourceType"` + ComputeCard string `json:"card"` + Strategy string `json:"strategy"` + StaticWeightMap map[string]int32 `json:"staticWeightMap,optional"` + Params []string `json:"params,optional"` + Envs []string `json:"envs,optional"` + Cmd string `json:"cmd,optional"` + Replica int32 `json:"replicas,optional"` +} diff --git a/api/internal/scheduler/service/collector/collector.go b/api/internal/scheduler/service/collector/collector.go index 2c8d51a8..1fa777ed 100644 --- a/api/internal/scheduler/service/collector/collector.go +++ b/api/internal/scheduler/service/collector/collector.go @@ -1,6 +1,9 @@ package collector -import "context" +import ( + "context" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/schedulers/option" +) type AiCollector interface { GetResourceStats(ctx context.Context) (*ResourceStats, error) @@ -12,6 +15,12 @@ type AiCollector interface { UploadAlgorithmCode(ctx context.Context, resourceType string, card string, taskType string, dataset string, algorithm string, code string) error GetComputeCards(ctx context.Context) ([]string, error) GetUserBalance(ctx context.Context) (float64, error) + GetImageInferUrl(ctx context.Context, option *option.InferOption) ([]*ImageInferUrl, error) +} + +type ImageInferUrl struct { + Url string + Card string } type ResourceStats struct { diff --git a/api/internal/storeLink/modelarts.go b/api/internal/storeLink/modelarts.go index 0652c8a5..80b0a193 100644 --- a/api/internal/storeLink/modelarts.go +++ b/api/internal/storeLink/modelarts.go @@ -376,3 +376,14 @@ func (m *ModelArtsLink) generateAlgorithmId(ctx context.Context, option *option. return errors.New("failed to get AlgorithmId") } + +func (m *ModelArtsLink) GetImageInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.ImageInferUrl, error) { + var imageUrls []*collector.ImageInferUrl + imageUrl := &collector.ImageInferUrl{ + Url: "http://0.0.0.0:8888/image", + Card: "npu", + } + imageUrls = append(imageUrls, imageUrl) + + return imageUrls, nil +} diff --git a/api/internal/storeLink/octopus.go b/api/internal/storeLink/octopus.go index 53c3652c..e84e54fd 100644 --- a/api/internal/storeLink/octopus.go +++ b/api/internal/storeLink/octopus.go @@ -870,3 +870,23 @@ func setResourceIdByCard(option *option.AiOption, specs *octopus.GetResourceSpec } return errors.New("set ResourceId error") } + +func (o *OctopusLink) GetImageInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.ImageInferUrl, error) { + var imageUrls []*collector.ImageInferUrl + imageUrl := &collector.ImageInferUrl{ + Url: "http://0.0.0.0:8888/image", + Card: "mlu", + } + imageUrl1 := &collector.ImageInferUrl{ + Url: "http://0.0.0.0:8888/image", + Card: "gcu", + } + imageUrl2 := &collector.ImageInferUrl{ + Url: "http://0.0.0.0:8888/image", + Card: "biv100", + } + imageUrls = append(imageUrls, imageUrl) + imageUrls = append(imageUrls, imageUrl1) + imageUrls = append(imageUrls, imageUrl2) + return imageUrls, nil +} diff --git a/api/internal/storeLink/shuguangai.go b/api/internal/storeLink/shuguangai.go index 03cb8928..d29427d2 100644 --- a/api/internal/storeLink/shuguangai.go +++ b/api/internal/storeLink/shuguangai.go @@ -729,3 +729,14 @@ func (s *ShuguangAi) generateParams(option *option.AiOption) error { return errors.New("failed to set params") } + +func (s *ShuguangAi) GetImageInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.ImageInferUrl, error) { + var imageUrls []*collector.ImageInferUrl + imageUrl := &collector.ImageInferUrl{ + Url: "http://0.0.0.0:8888/image", + Card: "dcu", + } + imageUrls = append(imageUrls, imageUrl) + + return imageUrls, nil +}