From bee190493f0a577e9ed05d34ad72a44f22c57a3d Mon Sep 17 00:00:00 2001 From: tzwang Date: Wed, 26 Jun 2024 16:31:09 +0800 Subject: [PATCH 1/2] updated taskmodel api type Former-commit-id: 8d89e912f552e5f3ca2390610186905b2886578d --- api/desc/core/pcm-core.api | 4 ++-- api/internal/types/types.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/api/desc/core/pcm-core.api b/api/desc/core/pcm-core.api index 516409cc..f8fb1ac7 100644 --- a/api/desc/core/pcm-core.api +++ b/api/desc/core/pcm-core.api @@ -404,8 +404,8 @@ type ( TenantId string `json:"tenantId,omitempty" db:"tenant_id"` CreatedTime string `json:"createdTime,omitempty" db:"created_time" gorm:"autoCreateTime"` UpdatedTime string `json:"updatedTime,omitempty" db:"updated_time"` - AdapterTypeDict int `json:"adapterTypeDict" db:"adapter_type_dict" gorm:"adapter_type_dict"` //适配器类型(对应字典表的值 - TaskTypeDict int `json:"taskTypeDict" db:"task_type_dict" gorm:"task_type_dict"` //任务类型(对应字典表的值 + AdapterTypeDict string `json:"adapterTypeDict" db:"adapter_type_dict" gorm:"adapter_type_dict"` //适配器类型(对应字典表的值 + TaskTypeDict string `json:"taskTypeDict" db:"task_type_dict" gorm:"task_type_dict"` //任务类型(对应字典表的值 } ) diff --git a/api/internal/types/types.go b/api/internal/types/types.go index 7ab16536..0b53492c 100644 --- a/api/internal/types/types.go +++ b/api/internal/types/types.go @@ -347,8 +347,8 @@ type TaskModel struct { TenantId string `json:"tenantId,omitempty" db:"tenant_id"` CreatedTime string `json:"createdTime,omitempty" db:"created_time" gorm:"autoCreateTime"` UpdatedTime string `json:"updatedTime,omitempty" db:"updated_time"` - AdapterTypeDict int `json:"adapterTypeDict" db:"adapter_type_dict" gorm:"adapter_type_dict"` //适配器类型(对应字典表的值 - TaskTypeDict int `json:"taskTypeDict" db:"task_type_dict" gorm:"task_type_dict"` //任务类型(对应字典表的值 + AdapterTypeDict string `json:"adapterTypeDict" db:"adapter_type_dict" gorm:"adapter_type_dict"` //适配器类型(对应字典表的值 + TaskTypeDict string `json:"taskTypeDict" db:"task_type_dict" gorm:"task_type_dict"` //任务类型(对应字典表的值 } type TaskDetailReq struct { From 28e9deea2c3c6a34ec5c20ebc7716e04e5bf6055 Mon Sep 17 00:00:00 2001 From: tzwang Date: Wed, 26 Jun 2024 18:35:15 +0800 Subject: [PATCH 2/2] fix imageinfer api bug Former-commit-id: 1b914196fd8e5069d25c4e39994720d501cd38f5 --- api/internal/cron/aiCronTask.go | 6 +- api/internal/logic/core/pagelisttasklogic.go | 4 +- .../logic/inference/imageinferencelogic.go | 406 +----------------- api/internal/scheduler/database/aiStorage.go | 17 + api/internal/scheduler/service/aiService.go | 14 + .../scheduler/service/inference/imageInfer.go | 385 +++++++++++++++++ 6 files changed, 433 insertions(+), 399 deletions(-) create mode 100644 api/internal/scheduler/service/inference/imageInfer.go diff --git a/api/internal/cron/aiCronTask.go b/api/internal/cron/aiCronTask.go index 254a16e6..fdf2a767 100644 --- a/api/internal/cron/aiCronTask.go +++ b/api/internal/cron/aiCronTask.go @@ -58,7 +58,7 @@ func UpdateAiTaskStatus(svc *svc.ServiceContext, tasklist []*types.TaskModel) { list := make([]*types.TaskModel, len(tasklist)) copy(list, tasklist) for i := len(list) - 1; i >= 0; i-- { - if list[i].AdapterTypeDict != 1 || list[i].Status == constants.Succeeded || list[i].Status == constants.Failed { + if list[i].AdapterTypeDict != "1" || list[i].Status == constants.Succeeded || list[i].Status == constants.Failed { list = append(list[:i], list[i+1:]...) } } @@ -155,7 +155,7 @@ func UpdateTaskStatus(svc *svc.ServiceContext, tasklist []*types.TaskModel) { list := make([]*types.TaskModel, len(tasklist)) copy(list, tasklist) for i := len(list) - 1; i >= 0; i-- { - if list[i].AdapterTypeDict != 1 || list[i].Status == constants.Succeeded || list[i].Status == constants.Failed { + if list[i].AdapterTypeDict != "1" || list[i].Status == constants.Succeeded || list[i].Status == constants.Failed { list = append(list[:i], list[i+1:]...) } } @@ -174,7 +174,7 @@ func UpdateTaskStatus(svc *svc.ServiceContext, tasklist []*types.TaskModel) { } // Update Infer Task Status - if task.TaskTypeDict == 11 || task.TaskTypeDict == 12 { + if task.TaskTypeDict == "11" || task.TaskTypeDict == "12" { UpdateInferTaskStatus(svc, task) return } diff --git a/api/internal/logic/core/pagelisttasklogic.go b/api/internal/logic/core/pagelisttasklogic.go index 29e5d000..fa56c9f0 100644 --- a/api/internal/logic/core/pagelisttasklogic.go +++ b/api/internal/logic/core/pagelisttasklogic.go @@ -90,7 +90,7 @@ func (l *PageListTaskLogic) updateTaskStatus(tasklist []*types.TaskModel, ch cha list := make([]*types.TaskModel, len(tasklist)) copy(list, tasklist) for i := len(list) - 1; i >= 0; i-- { - if list[i].AdapterTypeDict != 1 || list[i].Status == constants.Succeeded || list[i].Status == constants.Failed { + if list[i].AdapterTypeDict != "1" || list[i].Status == constants.Succeeded || list[i].Status == constants.Failed { list = append(list[:i], list[i+1:]...) } } @@ -228,7 +228,7 @@ func (l *PageListTaskLogic) updateAiTaskStatus(tasklist []*types.TaskModel, ch c list := make([]*types.TaskModel, len(tasklist)) copy(list, tasklist) for i := len(list) - 1; i >= 0; i-- { - if list[i].AdapterTypeDict != 1 || list[i].Status == constants.Succeeded || list[i].Status == constants.Failed { + if list[i].AdapterTypeDict != "1" || list[i].Status == constants.Succeeded || list[i].Status == constants.Failed { list = append(list[:i], list[i+1:]...) } } diff --git a/api/internal/logic/inference/imageinferencelogic.go b/api/internal/logic/inference/imageinferencelogic.go index 1b32bbb1..f10a3c53 100644 --- a/api/internal/logic/inference/imageinferencelogic.go +++ b/api/internal/logic/inference/imageinferencelogic.go @@ -3,25 +3,13 @@ package inference import ( "context" "errors" - "github.com/go-resty/resty/v2" "github.com/zeromicro/go-zero/core/logx" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/schedulers/option" - "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/collector" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/inference" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/strategy" - "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/storeLink" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types" - "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants" - "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models" - "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/utils" - "k8s.io/apimachinery/pkg/util/json" - "log" - "math/rand" - "mime/multipart" "net/http" - "strconv" - "sync" - "time" ) type ImageInferenceLogic struct { @@ -55,10 +43,7 @@ func (l *ImageInferenceLogic) ImageInfer(r *http.Request, req *types.ImageInfere StaticWeightMap: req.StaticWeightMap, } - var ts []struct { - imageResult *types.ImageResult - file multipart.File - } + var ts []*inference.ImageFile uploadedFiles := r.MultipartForm.File @@ -78,14 +63,11 @@ func (l *ImageInferenceLogic) ImageInfer(r *http.Request, req *types.ImageInfere defer file.Close() var ir types.ImageResult ir.ImageName = header.Filename - t := struct { - imageResult *types.ImageResult - file multipart.File - }{ - imageResult: &ir, - file: file, + t := inference.ImageFile{ + ImageResult: &ir, + File: file, } - ts = append(ts, t) + ts = append(ts, &t) } _, ok := l.svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId] @@ -108,396 +90,32 @@ func (l *ImageInferenceLogic) ImageInfer(r *http.Request, req *types.ImageInfere return nil, err } - results, err := infer(opt, clusters, ts, l.svcCtx, l.ctx) - if err != nil { - return nil, err - } - resp.InferResults = results - - return resp, nil -} - -func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []struct { - imageResult *types.ImageResult - file multipart.File -}, svcCtx *svc.ServiceContext, ctx context.Context) ([]*types.ImageResult, error) { - if clusters == nil || len(clusters) == 0 { return nil, errors.New("clusters is nil") } - for i := len(clusters) - 1; i >= 0; i-- { - if clusters[i].Replicas == 0 { - clusters = append(clusters[:i], clusters[i+1:]...) - } - } - - var wg sync.WaitGroup - var cluster_ch = make(chan struct { - urls []*collector.InferUrl - clusterId string - clusterName string - imageNum int32 - }, len(clusters)) - - var cs []struct { - urls []*collector.InferUrl - clusterId string - clusterName string - imageNum int32 - } - collectorMap := svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId] - //save task var synergystatus int64 if len(clusters) > 1 { synergystatus = 1 } - strategyCode, err := svcCtx.Scheduler.AiStorages.GetStrategyCode(opt.Strategy) + strategyCode, err := l.svcCtx.Scheduler.AiStorages.GetStrategyCode(opt.Strategy) if err != nil { return nil, err } - adapterName, err := svcCtx.Scheduler.AiStorages.GetAdapterNameById(opt.AdapterId) + adapterName, err := l.svcCtx.Scheduler.AiStorages.GetAdapterNameById(opt.AdapterId) if err != nil { return nil, err } - id, err := svcCtx.Scheduler.AiStorages.SaveTask(opt.TaskName, strategyCode, synergystatus, "11") + id, err := l.svcCtx.Scheduler.AiStorages.SaveTask(opt.TaskName, strategyCode, synergystatus, "11") if err != nil { return nil, err } - svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "create", "任务创建中") + l.svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "create", "任务创建中") - //save taskai - for _, c := range clusters { - clusterName, _ := svcCtx.Scheduler.AiStorages.GetClusterNameById(c.ClusterId) - opt.Replica = c.Replicas - err := svcCtx.Scheduler.AiStorages.SaveAiTask(id, opt, adapterName, c.ClusterId, clusterName, "", constants.Saved, "") - if err != nil { - return nil, err - } - } + go l.svcCtx.Scheduler.AiService.ImageInfer(opt, id, adapterName, clusters, ts, l.ctx) - for _, cluster := range clusters { - wg.Add(1) - c := cluster - go func() { - imageUrls, err := collectorMap[c.ClusterId].GetInferUrl(ctx, opt) - for i, _ := range imageUrls { - imageUrls[i].Url = imageUrls[i].Url + storeLink.FORWARD_SLASH + "image" - } - if err != nil { - wg.Done() - return - } - clusterName, _ := svcCtx.Scheduler.AiStorages.GetClusterNameById(c.ClusterId) - - s := struct { - urls []*collector.InferUrl - clusterId string - clusterName string - imageNum int32 - }{ - urls: imageUrls, - clusterId: c.ClusterId, - clusterName: clusterName, - imageNum: c.Replicas, - } - - cluster_ch <- s - wg.Done() - return - }() - } - wg.Wait() - close(cluster_ch) - - for s := range cluster_ch { - cs = append(cs, s) - } - - var aiTaskList []*models.TaskAi - tx := svcCtx.DbEngin.Raw("select * from task_ai where `task_id` = ? ", id).Scan(&aiTaskList) - if tx.Error != nil { - return nil, tx.Error - - } - - //no cluster available - if len(cs) == 0 { - for _, t := range aiTaskList { - t.Status = constants.Failed - t.EndTime = time.Now().Format(time.RFC3339) - err := svcCtx.Scheduler.AiStorages.UpdateAiTask(t) - if err != nil { - logx.Errorf(tx.Error.Error()) - } - } - svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败") - return nil, errors.New("image infer task failed") - } - - //change cluster status - if len(clusters) != len(cs) { - var acs []*strategy.AssignedCluster - var rcs []*strategy.AssignedCluster - for _, cluster := range clusters { - if contains(cs, cluster.ClusterId) { - var ac *strategy.AssignedCluster - ac = cluster - rcs = append(rcs, ac) - } else { - var ac *strategy.AssignedCluster - ac = cluster - acs = append(acs, ac) - } - } - - // update failed cluster status - for _, ac := range acs { - for _, t := range aiTaskList { - if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) { - t.Status = constants.Failed - t.EndTime = time.Now().Format(time.RFC3339) - err := svcCtx.Scheduler.AiStorages.UpdateAiTask(t) - if err != nil { - logx.Errorf(tx.Error.Error()) - } - } - } - } - - // update running cluster status - for _, ac := range rcs { - for _, t := range aiTaskList { - if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) { - t.Status = constants.Running - err := svcCtx.Scheduler.AiStorages.UpdateAiTask(t) - if err != nil { - logx.Errorf(tx.Error.Error()) - } - } - } - } - svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败") - } else { - for _, t := range aiTaskList { - t.Status = constants.Running - err := svcCtx.Scheduler.AiStorages.UpdateAiTask(t) - if err != nil { - logx.Errorf(tx.Error.Error()) - } - } - svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "running", "任务运行中") - } - - var result_ch = make(chan *types.ImageResult, len(ts)) - var results []*types.ImageResult - limit := make(chan bool, 7) - - var imageNumIdx int32 = 0 - var imageNumIdxEnd int32 = 0 - for _, c := range cs { - new_images := make([]struct { - imageResult *types.ImageResult - file multipart.File - }, len(ts)) - copy(new_images, ts) - - imageNumIdxEnd = imageNumIdxEnd + c.imageNum - new_images = new_images[imageNumIdx:imageNumIdxEnd] - imageNumIdx = imageNumIdx + c.imageNum - - wg.Add(len(new_images)) - go sendInferReq(new_images, c, &wg, result_ch, limit) - } - wg.Wait() - close(result_ch) - - for s := range result_ch { - results = append(results, s) - } - - //sort.Slice(results, func(p, q int) bool { - // return results[p].ClusterName < results[q].ClusterName - //}) - - //save ai sub tasks - for _, r := range results { - for _, task := range aiTaskList { - if r.ClusterId == strconv.Itoa(int(task.ClusterId)) { - taskAiSub := models.TaskAiSub{ - TaskId: id, - TaskName: task.Name, - TaskAiId: task.TaskId, - TaskAiName: task.Name, - ImageName: r.ImageName, - Result: r.ImageResult, - Card: r.Card, - ClusterId: task.ClusterId, - ClusterName: r.ClusterName, - } - tx := svcCtx.DbEngin.Table("task_ai_sub").Create(&taskAiSub) - if tx.Error != nil { - logx.Errorf(err.Error()) - } - } - } - } - - // update succeeded cluster status - var successStatusCount int - for _, c := range cs { - for _, t := range aiTaskList { - if c.clusterId == strconv.Itoa(int(t.ClusterId)) { - t.Status = constants.Completed - t.EndTime = time.Now().Format(time.RFC3339) - err := svcCtx.Scheduler.AiStorages.UpdateAiTask(t) - if err != nil { - logx.Errorf(tx.Error.Error()) - } - successStatusCount++ - } else { - continue - } - } - } - - if len(cs) == successStatusCount { - svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "completed", "任务完成") - } else { - svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败") - } - - return results, nil -} - -func sendInferReq(images []struct { - imageResult *types.ImageResult - file multipart.File -}, cluster struct { - urls []*collector.InferUrl - clusterId string - clusterName string - imageNum int32 -}, wg *sync.WaitGroup, ch chan<- *types.ImageResult, limit chan bool) { - for _, image := range images { - limit <- true - go func(t struct { - imageResult *types.ImageResult - file multipart.File - }, c struct { - urls []*collector.InferUrl - clusterId string - clusterName string - imageNum int32 - }) { - if len(c.urls) == 1 { - r, err := getInferResult(c.urls[0].Url, t.file, t.imageResult.ImageName, c.clusterName) - if err != nil { - t.imageResult.ImageResult = err.Error() - t.imageResult.ClusterId = c.clusterId - t.imageResult.ClusterName = c.clusterName - t.imageResult.Card = c.urls[0].Card - ch <- t.imageResult - wg.Done() - <-limit - return - } - t.imageResult.ImageResult = r - t.imageResult.ClusterId = c.clusterId - t.imageResult.ClusterName = c.clusterName - t.imageResult.Card = c.urls[0].Card - - ch <- t.imageResult - wg.Done() - <-limit - return - } else { - idx := rand.Intn(len(c.urls)) - r, err := getInferResult(c.urls[idx].Url, t.file, t.imageResult.ImageName, c.clusterName) - if err != nil { - t.imageResult.ImageResult = err.Error() - t.imageResult.ClusterId = c.clusterId - t.imageResult.ClusterName = c.clusterName - t.imageResult.Card = c.urls[idx].Card - ch <- t.imageResult - wg.Done() - <-limit - return - } - t.imageResult.ImageResult = r - t.imageResult.ClusterId = c.clusterId - t.imageResult.ClusterName = c.clusterName - t.imageResult.Card = c.urls[idx].Card - - ch <- t.imageResult - wg.Done() - <-limit - return - } - }(image, cluster) - <-limit - } -} - -func getInferResult(url string, file multipart.File, fileName string, clusterName string) (string, error) { - if clusterName == "鹏城云脑II-modelarts" { - r, err := getInferResultModelarts(url, file, fileName) - log.Printf("图形识别url: %s", url) - if err != nil { - return "", err - } - return r, nil - } - var res Res - req := GetRestyRequest(20) - _, err := req. - SetFileReader("file", fileName, file). - SetResult(&res). - Post(url) - if err != nil { - return "", err - } - return res.Result, nil -} - -func getInferResultModelarts(url string, file multipart.File, fileName string) (string, error) { - var res Res - body, err := utils.SendRequest("POST", url, file, fileName) - log.Printf("图形识别url: %s", url) - if err != nil { - return "", err - } - errjson := json.Unmarshal([]byte(body), &res) - if errjson != nil { - log.Fatalf("Error parsing JSON: %s", errjson) - } - log.Printf("推理结果: %s", res.Result) - return res.Result, nil -} - -func GetRestyRequest(timeoutSeconds int64) *resty.Request { - client := resty.New().SetTimeout(time.Duration(timeoutSeconds) * time.Second) - request := client.R() - return request -} - -type Res struct { - Result string `json:"result"` -} - -func contains(cs []struct { - urls []*collector.InferUrl - clusterId string - clusterName string - imageNum int32 -}, e string) bool { - for _, c := range cs { - if c.clusterId == e { - return true - } - } - return false + return resp, nil } diff --git a/api/internal/scheduler/database/aiStorage.go b/api/internal/scheduler/database/aiStorage.go index 8e4105e7..9aba3e52 100644 --- a/api/internal/scheduler/database/aiStorage.go +++ b/api/internal/scheduler/database/aiStorage.go @@ -94,6 +94,15 @@ func (s *AiStorage) GetAiTasksByAdapterId(adapterId string) ([]*models.TaskAi, e return resp, nil } +func (s *AiStorage) GetAiTaskListById(id int64) ([]*models.TaskAi, error) { + var aiTaskList []*models.TaskAi + tx := s.DbEngin.Raw("select * from task_ai where `task_id` = ? ", id).Scan(&aiTaskList) + if tx.Error != nil { + return nil, tx.Error + } + return aiTaskList, nil +} + func (s *AiStorage) SaveTask(name string, strategyCode int64, synergyStatus int64, aiType string) (int64, error) { startTime := time.Now() // 构建主任务结构体 @@ -165,6 +174,14 @@ func (s *AiStorage) SaveAiTask(taskId int64, opt option.Option, adapterName stri return nil } +func (s *AiStorage) SaveAiTaskImageSubTask(ta *models.TaskAiSub) error { + tx := s.DbEngin.Table("task_ai_sub").Create(ta) + if tx.Error != nil { + return tx.Error + } + return nil +} + func (s *AiStorage) SaveClusterTaskQueue(adapterId string, clusterId string, queueNum int64) error { aId, err := strconv.ParseInt(adapterId, 10, 64) if err != nil { diff --git a/api/internal/scheduler/service/aiService.go b/api/internal/scheduler/service/aiService.go index e2f27be7..1e1d9df0 100644 --- a/api/internal/scheduler/service/aiService.go +++ b/api/internal/scheduler/service/aiService.go @@ -1,12 +1,17 @@ package service import ( + "context" + "fmt" "github.com/zeromicro/go-zero/zrpc" hpcacclient "gitlink.org.cn/JointCloud/pcm-ac/hpcacclient" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/config" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/database" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/schedulers/option" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/collector" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/executor" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/inference" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/strategy" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/storeLink" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types" "gitlink.org.cn/JointCloud/pcm-modelarts/client/imagesservice" @@ -86,6 +91,15 @@ func InitAiClusterMap(conf *config.Config, clusters []types.ClusterInfo) (map[st return executorMap, collectorMap } +func (as *AiService) ImageInfer(opt *option.InferOption, id int64, adapterName string, clusters []*strategy.AssignedCluster, ts []*inference.ImageFile, ctx context.Context) { + + res, err := inference.Infer(opt, id, adapterName, clusters, ts, as.AiCollectorAdapterMap, as.Storage, ctx) + if err != nil { + return + } + fmt.Println(res) +} + //func (a *AiService) AddCluster() error { // //} diff --git a/api/internal/scheduler/service/inference/imageInfer.go b/api/internal/scheduler/service/inference/imageInfer.go new file mode 100644 index 00000000..732072b2 --- /dev/null +++ b/api/internal/scheduler/service/inference/imageInfer.go @@ -0,0 +1,385 @@ +package inference + +import ( + "context" + "encoding/json" + "errors" + "github.com/go-resty/resty/v2" + "github.com/zeromicro/go-zero/core/logx" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/database" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/schedulers/option" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/collector" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/strategy" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/storeLink" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types" + "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants" + "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models" + "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/utils" + "log" + "math/rand" + "mime/multipart" + "sort" + "strconv" + "sync" + "time" +) + +type ImageFile struct { + ImageResult *types.ImageResult + File multipart.File +} + +func Infer(opt *option.InferOption, id int64, adapterName string, clusters []*strategy.AssignedCluster, ts []*ImageFile, aiCollectorAdapterMap map[string]map[string]collector.AiCollector, storage *database.AiStorage, ctx context.Context) ([]*types.ImageResult, error) { + + for i := len(clusters) - 1; i >= 0; i-- { + if clusters[i].Replicas == 0 { + clusters = append(clusters[:i], clusters[i+1:]...) + } + } + + var wg sync.WaitGroup + var cluster_ch = make(chan struct { + urls []*collector.InferUrl + clusterId string + clusterName string + imageNum int32 + }, len(clusters)) + + var cs []struct { + urls []*collector.InferUrl + clusterId string + clusterName string + imageNum int32 + } + collectorMap := aiCollectorAdapterMap[opt.AdapterId] + + //save taskai + for _, c := range clusters { + clusterName, _ := storage.GetClusterNameById(c.ClusterId) + opt.Replica = c.Replicas + err := storage.SaveAiTask(id, opt, adapterName, c.ClusterId, clusterName, "", constants.Saved, "") + if err != nil { + return nil, err + } + } + + for _, cluster := range clusters { + wg.Add(1) + c := cluster + go func() { + imageUrls, err := collectorMap[c.ClusterId].GetInferUrl(ctx, opt) + for i, _ := range imageUrls { + imageUrls[i].Url = imageUrls[i].Url + storeLink.FORWARD_SLASH + "image" + } + if err != nil { + wg.Done() + return + } + clusterName, _ := storage.GetClusterNameById(c.ClusterId) + + s := struct { + urls []*collector.InferUrl + clusterId string + clusterName string + imageNum int32 + }{ + urls: imageUrls, + clusterId: c.ClusterId, + clusterName: clusterName, + imageNum: c.Replicas, + } + + cluster_ch <- s + wg.Done() + return + }() + } + wg.Wait() + close(cluster_ch) + + for s := range cluster_ch { + cs = append(cs, s) + } + + aiTaskList, err := storage.GetAiTaskListById(id) + if err != nil { + return nil, err + } + + //no cluster available + if len(cs) == 0 { + for _, t := range aiTaskList { + t.Status = constants.Failed + t.EndTime = time.Now().Format(time.RFC3339) + err := storage.UpdateAiTask(t) + if err != nil { + logx.Errorf(err.Error()) + } + } + storage.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败") + return nil, errors.New("image infer task failed") + } + + //change cluster status + if len(clusters) != len(cs) { + var acs []*strategy.AssignedCluster + var rcs []*strategy.AssignedCluster + for _, cluster := range clusters { + if contains(cs, cluster.ClusterId) { + var ac *strategy.AssignedCluster + ac = cluster + rcs = append(rcs, ac) + } else { + var ac *strategy.AssignedCluster + ac = cluster + acs = append(acs, ac) + } + } + + // update failed cluster status + for _, ac := range acs { + for _, t := range aiTaskList { + if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) { + t.Status = constants.Failed + t.EndTime = time.Now().Format(time.RFC3339) + err := storage.UpdateAiTask(t) + if err != nil { + logx.Errorf(err.Error()) + } + } + } + } + + // update running cluster status + for _, ac := range rcs { + for _, t := range aiTaskList { + if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) { + t.Status = constants.Running + err := storage.UpdateAiTask(t) + if err != nil { + logx.Errorf(err.Error()) + } + } + } + } + storage.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败") + } else { + for _, t := range aiTaskList { + t.Status = constants.Running + err := storage.UpdateAiTask(t) + if err != nil { + logx.Errorf(err.Error()) + } + } + storage.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "running", "任务运行中") + } + + var result_ch = make(chan *types.ImageResult, len(ts)) + var results []*types.ImageResult + limit := make(chan bool, 7) + + var imageNumIdx int32 = 0 + var imageNumIdxEnd int32 = 0 + for _, c := range cs { + new_images := make([]*ImageFile, len(ts)) + copy(new_images, ts) + + imageNumIdxEnd = imageNumIdxEnd + c.imageNum + new_images = new_images[imageNumIdx:imageNumIdxEnd] + imageNumIdx = imageNumIdx + c.imageNum + + wg.Add(len(new_images)) + go sendInferReq(new_images, c, &wg, result_ch, limit) + } + wg.Wait() + close(result_ch) + + for s := range result_ch { + results = append(results, s) + } + + sort.Slice(results, func(p, q int) bool { + return results[p].ClusterName < results[q].ClusterName + }) + + //save ai sub tasks + for _, r := range results { + for _, task := range aiTaskList { + if r.ClusterId == strconv.Itoa(int(task.ClusterId)) { + taskAiSub := models.TaskAiSub{ + TaskId: id, + TaskName: task.Name, + TaskAiId: task.TaskId, + TaskAiName: task.Name, + ImageName: r.ImageName, + Result: r.ImageResult, + Card: r.Card, + ClusterId: task.ClusterId, + ClusterName: r.ClusterName, + } + err := storage.SaveAiTaskImageSubTask(&taskAiSub) + if err != nil { + panic(err) + } + } + } + } + + // update succeeded cluster status + var successStatusCount int + for _, c := range cs { + for _, t := range aiTaskList { + if c.clusterId == strconv.Itoa(int(t.ClusterId)) { + t.Status = constants.Completed + t.EndTime = time.Now().Format(time.RFC3339) + err := storage.UpdateAiTask(t) + if err != nil { + logx.Errorf(err.Error()) + } + successStatusCount++ + } else { + continue + } + } + } + + if len(cs) == successStatusCount { + storage.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "completed", "任务完成") + } else { + storage.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败") + } + + return results, nil +} + +func sendInferReq(images []*ImageFile, cluster struct { + urls []*collector.InferUrl + clusterId string + clusterName string + imageNum int32 +}, wg *sync.WaitGroup, ch chan<- *types.ImageResult, limit chan bool) { + for _, image := range images { + limit <- true + go func(t *ImageFile, c struct { + urls []*collector.InferUrl + clusterId string + clusterName string + imageNum int32 + }) { + if len(c.urls) == 1 { + r, err := getInferResult(c.urls[0].Url, t.File, t.ImageResult.ImageName, c.clusterName) + if err != nil { + t.ImageResult.ImageResult = err.Error() + t.ImageResult.ClusterId = c.clusterId + t.ImageResult.ClusterName = c.clusterName + t.ImageResult.Card = c.urls[0].Card + ch <- t.ImageResult + wg.Done() + <-limit + return + } + t.ImageResult.ImageResult = r + t.ImageResult.ClusterId = c.clusterId + t.ImageResult.ClusterName = c.clusterName + t.ImageResult.Card = c.urls[0].Card + + ch <- t.ImageResult + wg.Done() + <-limit + return + } else { + idx := rand.Intn(len(c.urls)) + r, err := getInferResult(c.urls[idx].Url, t.File, t.ImageResult.ImageName, c.clusterName) + if err != nil { + t.ImageResult.ImageResult = err.Error() + t.ImageResult.ClusterId = c.clusterId + t.ImageResult.ClusterName = c.clusterName + t.ImageResult.Card = c.urls[idx].Card + ch <- t.ImageResult + wg.Done() + <-limit + return + } + t.ImageResult.ImageResult = r + t.ImageResult.ClusterId = c.clusterId + t.ImageResult.ClusterName = c.clusterName + t.ImageResult.Card = c.urls[idx].Card + + ch <- t.ImageResult + wg.Done() + <-limit + return + } + }(image, cluster) + <-limit + } +} + +func getInferResult(url string, file multipart.File, fileName string, clusterName string) (string, error) { + if clusterName == "鹏城云脑II-modelarts" { + r, err := getInferResultModelarts(url, file, fileName) + if err != nil { + return "", err + } + return r, nil + } + var res Res + req := GetRestyRequest(20) + _, err := req. + SetFileReader("file", fileName, file). + SetResult(&res). + Post(url) + if err != nil { + return "", err + } + return res.Result, nil +} + +func getInferResultModelarts(url string, file multipart.File, fileName string) (string, error) { + var res Res + /* req := GetRestyRequest(20) + _, err := req. + SetFileReader("file", fileName, file). + SetHeaders(map[string]string{ + "ak": "UNEHPHO4Z7YSNPKRXFE4", + "sk": "JWXCE9qcYbc7RjpSRIWt4WgG3ZKF6Q4lPzkJReX9", + }). + SetResult(&res). + Post(url) + if err != nil { + return "", err + }*/ + body, err := utils.SendRequest("POST", url, file, fileName) + if err != nil { + return "", err + } + errjson := json.Unmarshal([]byte(body), &res) + if errjson != nil { + log.Fatalf("Error parsing JSON: %s", errjson) + } + return res.Result, nil +} + +func GetRestyRequest(timeoutSeconds int64) *resty.Request { + client := resty.New().SetTimeout(time.Duration(timeoutSeconds) * time.Second) + request := client.R() + return request +} + +type Res struct { + Result string `json:"result"` +} + +func contains(cs []struct { + urls []*collector.InferUrl + clusterId string + clusterName string + imageNum int32 +}, e string) bool { + for _, c := range cs { + if c.clusterId == e { + return true + } + } + return false +}