diff --git a/api/internal/logic/inference/imageinferencelogic.go b/api/internal/logic/inference/imageinferencelogic.go index 5230c0be..09fbf0e5 100644 --- a/api/internal/logic/inference/imageinferencelogic.go +++ b/api/internal/logic/inference/imageinferencelogic.go @@ -13,7 +13,9 @@ import ( "math/rand" "mime/multipart" "net/http" + "sort" "sync" + "time" ) type ImageInferenceLogic struct { @@ -88,7 +90,6 @@ func (l *ImageInferenceLogic) ImageInfer(r *http.Request, req *types.ImageInfere var strat strategy.Strategy switch opt.Strategy { case strategy.STATIC_WEIGHT: - //todo resources should match cluster StaticWeightMap strat = strategy.NewStaticWeightStrategy(opt.StaticWeightMap, int32(len(ts))) if err != nil { return nil, err @@ -145,6 +146,7 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s go func() { imageUrls, err := collectorMap[c.ClusterId].GetImageInferUrl(ctx, opt) if err != nil { + wg.Done() return } clusterName, _ := svcCtx.Scheduler.AiStorages.GetClusterNameById(c.ClusterId) @@ -199,6 +201,10 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s results = append(results, s) } + sort.Slice(results, func(p, q int) bool { + return results[p].ClusterName < results[q].ClusterName + }) + return results, nil } @@ -223,6 +229,8 @@ func sendInferReq(images []struct { r, err := getInferResult(c.urls[0].Url, t.file, t.imageResult.ImageName) if err != nil { t.imageResult.ImageResult = err.Error() + t.imageResult.ClusterName = c.clusterName + t.imageResult.Card = c.urls[0].Card ch <- t.imageResult wg.Done() return @@ -239,6 +247,8 @@ func sendInferReq(images []struct { r, err := getInferResult(c.urls[idx].Url, t.file, t.imageResult.ImageName) if err != nil { t.imageResult.ImageResult = err.Error() + t.imageResult.ClusterName = c.clusterName + t.imageResult.Card = c.urls[idx].Card ch <- t.imageResult wg.Done() return @@ -257,7 +267,7 @@ func sendInferReq(images []struct { func getInferResult(url string, file multipart.File, fileName string) (string, error) { var res Res - req := GetACHttpRequest() + req := GetRestyRequest(10) _, err := req. SetFileReader("file", fileName, file). SetResult(&res). @@ -269,8 +279,8 @@ func getInferResult(url string, file multipart.File, fileName string) (string, e return res.Result, nil } -func GetACHttpRequest() *resty.Request { - client := resty.New() +func GetRestyRequest(timeoutSeconds int64) *resty.Request { + client := resty.New().SetTimeout(time.Duration(timeoutSeconds) * time.Second) request := client.R() return request } diff --git a/api/internal/storeLink/octopus.go b/api/internal/storeLink/octopus.go index e84e54fd..ff25a0d3 100644 --- a/api/internal/storeLink/octopus.go +++ b/api/internal/storeLink/octopus.go @@ -872,21 +872,28 @@ func setResourceIdByCard(option *option.AiOption, specs *octopus.GetResourceSpec } func (o *OctopusLink) GetImageInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.ImageInferUrl, error) { + req := &octopus.GetNotebookListReq{ + Platform: o.platform, + PageIndex: o.pageIndex, + PageSize: o.pageSize, + } + list, err := o.octopusRpc.GetNotebookList(ctx, req) + if err != nil { + return nil, err + } + var imageUrls []*collector.ImageInferUrl - imageUrl := &collector.ImageInferUrl{ - Url: "http://0.0.0.0:8888/image", - Card: "mlu", + for _, notebook := range list.Payload.GetNotebooks() { + if strings.Contains(notebook.AlgorithmName, option.ModelName) { + names := strings.Split(notebook.AlgorithmName, UNDERSCORE) + imageUrl := &collector.ImageInferUrl{ + Url: DOMAIN + notebook.Tasks[0].Url + FORWARD_SLASH + "image", + Card: names[2], + } + imageUrls = append(imageUrls, imageUrl) + } else { + continue + } } - imageUrl1 := &collector.ImageInferUrl{ - Url: "http://0.0.0.0:8888/image", - Card: "gcu", - } - imageUrl2 := &collector.ImageInferUrl{ - Url: "http://0.0.0.0:8888/image", - Card: "biv100", - } - imageUrls = append(imageUrls, imageUrl) - imageUrls = append(imageUrls, imageUrl1) - imageUrls = append(imageUrls, imageUrl2) return imageUrls, nil }