From 3b030e661d1fea19768c6029c4980ef851c08628 Mon Sep 17 00:00:00 2001 From: tzwang Date: Thu, 18 Jul 2024 10:54:35 +0800 Subject: [PATCH] update imageinference Former-commit-id: a3cde2d1d366d5d6788f659964049ab383b5a7b4 --- .../logic/inference/imageinferencelogic.go | 33 +- .../inference/texttotextinferencelogic.go | 111 +---- internal/scheduler/service/aiService.go | 8 +- .../imageInference/imageClassification.go | 413 +-------------- .../imageInference/imageInference.go | 470 +++++++++++++++++- .../inference/imageInference/imageToText.go | 29 +- .../scheduler/service/inference/inference.go | 380 +------------- .../inference/textInference/textInference.go | 96 ++++ .../inference/textInference/textToImage.go | 48 ++ .../inference/textInference/textToText.go | 131 +++++ 10 files changed, 809 insertions(+), 910 deletions(-) create mode 100644 internal/scheduler/service/inference/textInference/textToImage.go create mode 100644 internal/scheduler/service/inference/textInference/textToText.go diff --git a/internal/logic/inference/imageinferencelogic.go b/internal/logic/inference/imageinferencelogic.go index d381b0f2..0ad1c127 100644 --- a/internal/logic/inference/imageinferencelogic.go +++ b/internal/logic/inference/imageinferencelogic.go @@ -5,11 +5,11 @@ import ( "errors" "github.com/zeromicro/go-zero/core/logx" "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option" + "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference" "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference/imageInference" "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/strategy" "gitlink.org.cn/JointCloud/pcm-coordinator/internal/svc" "gitlink.org.cn/JointCloud/pcm-coordinator/internal/types" - "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants" "net/http" ) @@ -102,44 +102,31 @@ func (l *ImageInferenceLogic) ImageInfer(r *http.Request, req *types.ImageInfere } } - //save task - var synergystatus int64 - if len(clusters) > 1 { - synergystatus = 1 - } - - strategyCode, err := l.svcCtx.Scheduler.AiStorages.GetStrategyCode(opt.Strategy) - if err != nil { - return nil, err - } adapterName, err := l.svcCtx.Scheduler.AiStorages.GetAdapterNameById(opt.AdapterId) if err != nil { return nil, err } - id, err := l.svcCtx.Scheduler.AiStorages.SaveTask(opt.TaskName, strategyCode, synergystatus, "11") + + imageInfer, err := imageInference.New(imageInference.NewImageClassification(), ts, clusters, opt, l.svcCtx.Scheduler.AiStorages, l.svcCtx.Scheduler.AiService.InferenceAdapterMap, adapterName) if err != nil { return nil, err } - l.svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "create", "任务创建中") + in := inference.Inference{ + In: imageInfer, + } - //save taskai - for _, c := range clusters { - clusterName, _ := l.svcCtx.Scheduler.AiStorages.GetClusterNameById(c.ClusterId) - opt.Replica = c.Replicas - err := l.svcCtx.Scheduler.AiStorages.SaveAiTask(id, opt, adapterName, c.ClusterId, clusterName, "", constants.Saved, "") - if err != nil { - return nil, err - } + id, err := in.In.CreateTask() + if err != nil { + return nil, err } go func() { - ic, err := imageInference.NewImageClassification(ts, clusters, opt, l.svcCtx.Scheduler.AiStorages, l.svcCtx.Scheduler.AiService.InferenceAdapterMap, id, adapterName) + err := in.In.InferTask(id) if err != nil { logx.Errorf(err.Error()) return } - ic.Classify() }() return resp, nil diff --git a/internal/logic/inference/texttotextinferencelogic.go 
b/internal/logic/inference/texttotextinferencelogic.go index 83894dff..2507f2c6 100644 --- a/internal/logic/inference/texttotextinferencelogic.go +++ b/internal/logic/inference/texttotextinferencelogic.go @@ -6,14 +6,9 @@ import ( "github.com/zeromicro/go-zero/core/logx" "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option" "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference" - "gitlink.org.cn/JointCloud/pcm-coordinator/internal/storeLink" + "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference/textInference" "gitlink.org.cn/JointCloud/pcm-coordinator/internal/svc" "gitlink.org.cn/JointCloud/pcm-coordinator/internal/types" - "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants" - "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models" - "strconv" - "sync" - "time" ) type TextToTextInferenceLogic struct { @@ -46,105 +41,29 @@ func (l *TextToTextInferenceLogic) TextToTextInference(req *types.TextToTextInfe return nil, errors.New("AdapterId does not exist") } - //save task - var synergystatus int64 - var strategyCode int64 adapterName, err := l.svcCtx.Scheduler.AiStorages.GetAdapterNameById(opt.AdapterId) + + inType, err := textInference.NewTextToText(opt, l.svcCtx.Scheduler.AiStorages, l.svcCtx.Scheduler.AiService.InferenceAdapterMap) + if err != nil { + return nil, err + } + textInfer, err := textInference.New(inType, opt, l.svcCtx.Scheduler.AiStorages, l.svcCtx.Scheduler.AiService.InferenceAdapterMap, adapterName) if err != nil { return nil, err } - id, err := l.svcCtx.Scheduler.AiStorages.SaveTask(opt.TaskName, strategyCode, synergystatus, "12") + in := inference.Inference{ + In: textInfer, + } + + id, err := in.In.CreateTask() if err != nil { return nil, err } - - var wg sync.WaitGroup - var cluster_ch = make(chan struct { - urls []*inference.InferUrl - clusterId string - clusterName string - }, len(opt.AiClusterIds)) - - var cs []struct { - urls []*inference.InferUrl - clusterId string - clusterName string + err = in.In.InferTask(id) + if err != nil { + return nil, err } - inferMap := l.svcCtx.Scheduler.AiService.InferenceAdapterMap[opt.AdapterId] - - //save taskai - for _, clusterId := range opt.AiClusterIds { - wg.Add(1) - go func(cId string) { - urls, err := inferMap[cId].GetInferUrl(l.ctx, opt) - if err != nil { - wg.Done() - return - } - clusterName, _ := l.svcCtx.Scheduler.AiStorages.GetClusterNameById(cId) - - s := struct { - urls []*inference.InferUrl - clusterId string - clusterName string - }{ - urls: urls, - clusterId: cId, - clusterName: clusterName, - } - - cluster_ch <- s - wg.Done() - return - }(clusterId) - } - wg.Wait() - close(cluster_ch) - - for s := range cluster_ch { - cs = append(cs, s) - } - - if len(cs) == 0 { - clusterId := opt.AiClusterIds[0] - clusterName, _ := l.svcCtx.Scheduler.AiStorages.GetClusterNameById(opt.AiClusterIds[0]) - err := l.svcCtx.Scheduler.AiStorages.SaveAiTask(id, opt, adapterName, clusterId, clusterName, "", constants.Failed, "") - if err != nil { - return nil, err - } - l.svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败") - } - - for _, c := range cs { - clusterName, _ := l.svcCtx.Scheduler.AiStorages.GetClusterNameById(c.clusterId) - err := l.svcCtx.Scheduler.AiStorages.SaveAiTask(id, opt, adapterName, c.clusterId, clusterName, "", constants.Saved, "") - if err != nil { - return nil, err - } - } - - var aiTaskList []*models.TaskAi - tx := l.svcCtx.DbEngin.Raw("select * from task_ai where `task_id` 
= ? ", id).Scan(&aiTaskList) - if tx.Error != nil { - return nil, tx.Error - - } - - for i, t := range aiTaskList { - if strconv.Itoa(int(t.ClusterId)) == cs[i].clusterId { - t.Status = constants.Completed - t.EndTime = time.Now().Format(time.RFC3339) - url := cs[i].urls[0].Url + storeLink.FORWARD_SLASH + "chat" - t.InferUrl = url - err := l.svcCtx.Scheduler.AiStorages.UpdateAiTask(t) - if err != nil { - logx.Errorf(tx.Error.Error()) - } - } - } - - l.svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "completed", "任务完成") return resp, nil } diff --git a/internal/scheduler/service/aiService.go b/internal/scheduler/service/aiService.go index 4f9e6630..1af4f282 100644 --- a/internal/scheduler/service/aiService.go +++ b/internal/scheduler/service/aiService.go @@ -26,7 +26,7 @@ const ( type AiService struct { AiExecutorAdapterMap map[string]map[string]executor.AiExecutor AiCollectorAdapterMap map[string]map[string]collector.AiCollector - InferenceAdapterMap map[string]map[string]inference.Inference + InferenceAdapterMap map[string]map[string]inference.ICluster Storage *database.AiStorage mu sync.Mutex } @@ -40,7 +40,7 @@ func NewAiService(conf *config.Config, storages *database.AiStorage) (*AiService aiService := &AiService{ AiExecutorAdapterMap: make(map[string]map[string]executor.AiExecutor), AiCollectorAdapterMap: make(map[string]map[string]collector.AiCollector), - InferenceAdapterMap: make(map[string]map[string]inference.Inference), + InferenceAdapterMap: make(map[string]map[string]inference.ICluster), Storage: storages, } for _, id := range adapterIds { @@ -60,10 +60,10 @@ func NewAiService(conf *config.Config, storages *database.AiStorage) (*AiService return aiService, nil } -func InitAiClusterMap(conf *config.Config, clusters []types.ClusterInfo) (map[string]executor.AiExecutor, map[string]collector.AiCollector, map[string]inference.Inference) { +func InitAiClusterMap(conf *config.Config, clusters []types.ClusterInfo) (map[string]executor.AiExecutor, map[string]collector.AiCollector, map[string]inference.ICluster) { executorMap := make(map[string]executor.AiExecutor) collectorMap := make(map[string]collector.AiCollector) - inferenceMap := make(map[string]inference.Inference) + inferenceMap := make(map[string]inference.ICluster) for _, c := range clusters { switch c.Name { case OCTOPUS: diff --git a/internal/scheduler/service/inference/imageInference/imageClassification.go b/internal/scheduler/service/inference/imageInference/imageClassification.go index a9d10912..ad00245c 100644 --- a/internal/scheduler/service/inference/imageInference/imageClassification.go +++ b/internal/scheduler/service/inference/imageInference/imageClassification.go @@ -1,419 +1,26 @@ package imageInference -import ( - "encoding/json" - "errors" - "github.com/go-resty/resty/v2" - "github.com/zeromicro/go-zero/core/logx" - "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/database" - "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option" - "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference" - "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/strategy" - "gitlink.org.cn/JointCloud/pcm-coordinator/internal/types" - "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants" - "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models" - "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/utils" - "log" - "math/rand" - "mime/multipart" - "net/http" - "sort" - "strconv" - "sync" - "time" -) +import 
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference" const ( - IMAGE = "image" - FORWARD_SLASH = "/" + CLASSIFICATION = "image" + CLASSIFICATION_AiTYPE = "11" ) -type ImageClassificationInterface interface { - Classify() ([]*types.ImageResult, error) -} - -type ImageFile struct { - ImageResult *types.ImageResult - File multipart.File -} - -type FilteredCluster struct { - urls []*inference.InferUrl - clusterId string - clusterName string - imageNum int32 -} - type ImageClassification struct { - files []*ImageFile - clusters []*strategy.AssignedCluster - opt *option.InferOption - storage *database.AiStorage - inferAdapter map[string]map[string]inference.Inference - errMap map[string]string - taskId int64 - adapterName string - aiTaskList []*models.TaskAi } -func NewImageClassification(files []*ImageFile, - clusters []*strategy.AssignedCluster, - opt *option.InferOption, - storage *database.AiStorage, - inferAdapter map[string]map[string]inference.Inference, - taskId int64, - adapterName string) (*ImageClassification, error) { - - aiTaskList, err := storage.GetAiTaskListById(taskId) - if err != nil || len(aiTaskList) == 0 { - return nil, err - } - return &ImageClassification{ - files: files, - clusters: clusters, - opt: opt, - storage: storage, - inferAdapter: inferAdapter, - taskId: taskId, - adapterName: adapterName, - errMap: make(map[string]string), - aiTaskList: aiTaskList, - }, nil +func NewImageClassification() *ImageClassification { + return &ImageClassification{} } -func (i *ImageClassification) Classify() ([]*types.ImageResult, error) { - clusters, err := i.filterClusters() - if err != nil { - return nil, err - } - err = i.updateStatus(clusters) - if err != nil { - return nil, err - } - results, err := i.inferImages(clusters) - if err != nil { - return nil, err - } - return results, nil -} - -func (i *ImageClassification) filterClusters() ([]*FilteredCluster, error) { - var wg sync.WaitGroup - var ch = make(chan *FilteredCluster, len(i.clusters)) - var cs []*FilteredCluster - var mutex sync.Mutex - - inferMap := i.inferAdapter[i.opt.AdapterId] - - for _, cluster := range i.clusters { - wg.Add(1) - c := cluster - go func() { - r := http.Request{} - imageUrls, err := inferMap[c.ClusterId].GetInferUrl(r.Context(), i.opt) - if err != nil { - mutex.Lock() - i.errMap[c.ClusterId] = err.Error() - mutex.Unlock() - wg.Done() - return - } - for i, _ := range imageUrls { - imageUrls[i].Url = imageUrls[i].Url + FORWARD_SLASH + IMAGE - } - clusterName, _ := i.storage.GetClusterNameById(c.ClusterId) - - var f FilteredCluster - f.urls = imageUrls - f.clusterId = c.ClusterId - f.clusterName = clusterName - f.imageNum = c.Replicas - - ch <- &f - wg.Done() - return - }() - } - wg.Wait() - close(ch) - - for s := range ch { - cs = append(cs, s) - } - return cs, nil -} - -func (i *ImageClassification) inferImages(cs []*FilteredCluster) ([]*types.ImageResult, error) { - var wg sync.WaitGroup - var ch = make(chan *types.ImageResult, len(i.files)) - var results []*types.ImageResult - limit := make(chan bool, 7) - - var imageNumIdx int32 = 0 - var imageNumIdxEnd int32 = 0 - for _, c := range cs { - new_images := make([]*ImageFile, len(i.files)) - copy(new_images, i.files) - - imageNumIdxEnd = imageNumIdxEnd + c.imageNum - new_images = new_images[imageNumIdx:imageNumIdxEnd] - imageNumIdx = imageNumIdx + c.imageNum - - wg.Add(len(new_images)) - go sendInferReq(new_images, c, &wg, ch, limit) - } - wg.Wait() - close(ch) - - for s := range ch { - results = append(results, s) - } - - 
sort.Slice(results, func(p, q int) bool { - return results[p].ClusterName < results[q].ClusterName - }) - - //save ai sub tasks - for _, r := range results { - for _, task := range i.aiTaskList { - if r.ClusterId == strconv.Itoa(int(task.ClusterId)) { - taskAiSub := models.TaskAiSub{ - TaskId: i.taskId, - TaskName: task.Name, - TaskAiId: task.TaskId, - TaskAiName: task.Name, - ImageName: r.ImageName, - Result: r.ImageResult, - Card: r.Card, - ClusterId: task.ClusterId, - ClusterName: r.ClusterName, - } - err := i.storage.SaveAiTaskImageSubTask(&taskAiSub) - if err != nil { - panic(err) - } - } - } - } - - // update succeeded cluster status - var successStatusCount int - for _, c := range cs { - for _, t := range i.aiTaskList { - if c.clusterId == strconv.Itoa(int(t.ClusterId)) { - t.Status = constants.Completed - t.EndTime = time.Now().Format(time.RFC3339) - err := i.storage.UpdateAiTask(t) - if err != nil { - logx.Errorf(err.Error()) - } - successStatusCount++ - } else { - continue - } - } - } - - if len(cs) == successStatusCount { - i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "completed", "任务完成") - } else { - i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "failed", "任务失败") - } - - return results, nil -} - -func (i *ImageClassification) updateStatus(cs []*FilteredCluster) error { - - //no cluster available - if len(cs) == 0 { - for _, t := range i.aiTaskList { - t.Status = constants.Failed - t.EndTime = time.Now().Format(time.RFC3339) - if _, ok := i.errMap[strconv.Itoa(int(t.ClusterId))]; ok { - t.Msg = i.errMap[strconv.Itoa(int(t.ClusterId))] - } - err := i.storage.UpdateAiTask(t) - if err != nil { - logx.Errorf(err.Error()) - } - } - i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "failed", "任务失败") - return errors.New("image infer task failed") - } - - //change cluster status - if len(i.clusters) != len(cs) { - var acs []*strategy.AssignedCluster - var rcs []*strategy.AssignedCluster - for _, cluster := range i.clusters { - if contains(cs, cluster.ClusterId) { - var ac *strategy.AssignedCluster - ac = cluster - rcs = append(rcs, ac) - } else { - var ac *strategy.AssignedCluster - ac = cluster - acs = append(acs, ac) - } - } - - // update failed cluster status - for _, ac := range acs { - for _, t := range i.aiTaskList { - if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) { - t.Status = constants.Failed - t.EndTime = time.Now().Format(time.RFC3339) - if _, ok := i.errMap[strconv.Itoa(int(t.ClusterId))]; ok { - t.Msg = i.errMap[strconv.Itoa(int(t.ClusterId))] - } - err := i.storage.UpdateAiTask(t) - if err != nil { - logx.Errorf(err.Error()) - } - } - } - } - - // update running cluster status - for _, ac := range rcs { - for _, t := range i.aiTaskList { - if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) { - t.Status = constants.Running - err := i.storage.UpdateAiTask(t) - if err != nil { - logx.Errorf(err.Error()) - } - } - } - } - i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "failed", "任务失败") - } else { - for _, t := range i.aiTaskList { - t.Status = constants.Running - err := i.storage.UpdateAiTask(t) - if err != nil { - logx.Errorf(err.Error()) - } - } - i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "running", "任务运行中") +func (ic *ImageClassification) AppendRoute(urls []*inference.InferUrl) error { + for i, _ := range urls { + urls[i].Url = urls[i].Url + inference.FORWARD_SLASH + CLASSIFICATION } return nil } -func 
sendInferReq(images []*ImageFile, cluster *FilteredCluster, wg *sync.WaitGroup, ch chan<- *types.ImageResult, limit chan bool) { - for _, image := range images { - limit <- true - go func(t *ImageFile, c *FilteredCluster) { - if len(c.urls) == 1 { - r, err := getInferResult(c.urls[0].Url, t.File, t.ImageResult.ImageName, c.clusterName) - if err != nil { - t.ImageResult.ImageResult = err.Error() - t.ImageResult.ClusterId = c.clusterId - t.ImageResult.ClusterName = c.clusterName - t.ImageResult.Card = c.urls[0].Card - ch <- t.ImageResult - wg.Done() - <-limit - return - } - t.ImageResult.ImageResult = r - t.ImageResult.ClusterId = c.clusterId - t.ImageResult.ClusterName = c.clusterName - t.ImageResult.Card = c.urls[0].Card - - ch <- t.ImageResult - wg.Done() - <-limit - return - } else { - idx := rand.Intn(len(c.urls)) - r, err := getInferResult(c.urls[idx].Url, t.File, t.ImageResult.ImageName, c.clusterName) - if err != nil { - t.ImageResult.ImageResult = err.Error() - t.ImageResult.ClusterId = c.clusterId - t.ImageResult.ClusterName = c.clusterName - t.ImageResult.Card = c.urls[idx].Card - ch <- t.ImageResult - wg.Done() - <-limit - return - } - t.ImageResult.ImageResult = r - t.ImageResult.ClusterId = c.clusterId - t.ImageResult.ClusterName = c.clusterName - t.ImageResult.Card = c.urls[idx].Card - - ch <- t.ImageResult - wg.Done() - <-limit - return - } - }(image, cluster) - <-limit - } -} - -func getInferResult(url string, file multipart.File, fileName string, clusterName string) (string, error) { - if clusterName == "鹏城云脑II-modelarts" { - r, err := getInferResultModelarts(url, file, fileName) - if err != nil { - return "", err - } - return r, nil - } - var res Res - req := GetRestyRequest(20) - _, err := req. - SetFileReader("file", fileName, file). - SetResult(&res). - Post(url) - if err != nil { - return "", err - } - return res.Result, nil -} - -func getInferResultModelarts(url string, file multipart.File, fileName string) (string, error) { - var res Res - /* req := GetRestyRequest(20) - _, err := req. - SetFileReader("file", fileName, file). - SetHeaders(map[string]string{ - "ak": "UNEHPHO4Z7YSNPKRXFE4", - "sk": "JWXCE9qcYbc7RjpSRIWt4WgG3ZKF6Q4lPzkJReX9", - }). - SetResult(&res). 
- Post(url) - if err != nil { - return "", err - }*/ - body, err := utils.SendRequest("POST", url, file, fileName) - if err != nil { - return "", err - } - errjson := json.Unmarshal([]byte(body), &res) - if errjson != nil { - log.Fatalf("Error parsing JSON: %s", errjson) - } - return res.Result, nil -} - -func GetRestyRequest(timeoutSeconds int64) *resty.Request { - client := resty.New().SetTimeout(time.Duration(timeoutSeconds) * time.Second) - request := client.R() - return request -} - -type Res struct { - Result string `json:"result"` -} - -func contains(cs []*FilteredCluster, e string) bool { - for _, c := range cs { - if c.clusterId == e { - return true - } - } - return false +func (ic *ImageClassification) GetAiType() string { + return CLASSIFICATION_AiTYPE } diff --git a/internal/scheduler/service/inference/imageInference/imageInference.go b/internal/scheduler/service/inference/imageInference/imageInference.go index e82da39a..7ac6ec66 100644 --- a/internal/scheduler/service/inference/imageInference/imageInference.go +++ b/internal/scheduler/service/inference/imageInference/imageInference.go @@ -1,9 +1,469 @@ package imageInference -import "gitlink.org.cn/JointCloud/pcm-coordinator/internal/types" +import ( + "encoding/json" + "errors" + "github.com/go-resty/resty/v2" + "github.com/zeromicro/go-zero/core/logx" + "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/database" + "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option" + "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference" + "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/strategy" + "gitlink.org.cn/JointCloud/pcm-coordinator/internal/types" + "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants" + "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models" + "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/utils" + "log" + "math/rand" + "mime/multipart" + "net/http" + "sort" + "strconv" + "sync" + "time" +) -type ImageInference interface { - filterClusters() ([]*FilteredCluster, error) - inferImages(cs []*FilteredCluster) ([]*types.ImageResult, error) - updateStatus(cs []*FilteredCluster) error +type IImageInference interface { + AppendRoute(urls []*inference.InferUrl) error + GetAiType() string +} + +type ImageFile struct { + ImageResult *types.ImageResult + File multipart.File +} + +type FilteredCluster struct { + urls []*inference.InferUrl + clusterId string + clusterName string + imageNum int32 +} + +type ImageInference struct { + inference IImageInference + files []*ImageFile + clusters []*strategy.AssignedCluster + opt *option.InferOption + storage *database.AiStorage + inferAdapter map[string]map[string]inference.ICluster + errMap map[string]string + adapterName string +} + +func New( + inference IImageInference, + files []*ImageFile, + clusters []*strategy.AssignedCluster, + opt *option.InferOption, + storage *database.AiStorage, + inferAdapter map[string]map[string]inference.ICluster, + adapterName string) (*ImageInference, error) { + + return &ImageInference{ + inference: inference, + files: files, + clusters: clusters, + opt: opt, + storage: storage, + inferAdapter: inferAdapter, + adapterName: adapterName, + errMap: make(map[string]string), + }, nil +} + +func (i *ImageInference) CreateTask() (int64, error) { + id, err := i.saveTask() + if err != nil { + return 0, err + } + err = i.saveAiTask(id) + if err != nil { + return 0, err + } + return id, nil +} + +func (i *ImageInference) InferTask(id int64) error { + clusters, err := 
i.filterClusters() + if err != nil { + return err + } + + aiTaskList, err := i.storage.GetAiTaskListById(id) + if err != nil || len(aiTaskList) == 0 { + return err + } + + err = i.updateStatus(aiTaskList, clusters) + if err != nil { + return err + } + results, err := i.inferImages(clusters) + if err != nil { + return err + } + err = i.saveAiSubTasks(id, aiTaskList, clusters, results) + if err != nil { + return err + } + return nil +} + +func (i *ImageInference) saveTask() (int64, error) { + var synergystatus int64 + if len(i.clusters) > 1 { + synergystatus = 1 + } + + strategyCode, err := i.storage.GetStrategyCode(i.opt.Strategy) + if err != nil { + return 0, err + } + + id, err := i.storage.SaveTask(i.opt.TaskName, strategyCode, synergystatus, i.inference.GetAiType()) + if err != nil { + return 0, err + } + + i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "create", "任务创建中") + + return id, nil +} + +func (i *ImageInference) saveAiTask(id int64) error { + for _, c := range i.clusters { + clusterName, _ := i.storage.GetClusterNameById(c.ClusterId) + i.opt.Replica = c.Replicas + err := i.storage.SaveAiTask(id, i.opt, i.adapterName, c.ClusterId, clusterName, "", constants.Saved, "") + if err != nil { + return err + } + } + return nil +} + +func (i *ImageInference) filterClusters() ([]*FilteredCluster, error) { + var wg sync.WaitGroup + var ch = make(chan *FilteredCluster, len(i.clusters)) + var cs []*FilteredCluster + var mutex sync.Mutex + + inferMap := i.inferAdapter[i.opt.AdapterId] + + for _, cluster := range i.clusters { + wg.Add(1) + c := cluster + go func() { + r := http.Request{} + imageUrls, err := inferMap[c.ClusterId].GetInferUrl(r.Context(), i.opt) + if err != nil { + mutex.Lock() + i.errMap[c.ClusterId] = err.Error() + mutex.Unlock() + wg.Done() + return + } + + i.inference.AppendRoute(imageUrls) + + clusterName, _ := i.storage.GetClusterNameById(c.ClusterId) + + var f FilteredCluster + f.urls = imageUrls + f.clusterId = c.ClusterId + f.clusterName = clusterName + f.imageNum = c.Replicas + + ch <- &f + wg.Done() + return + }() + } + wg.Wait() + close(ch) + + for s := range ch { + cs = append(cs, s) + } + return cs, nil +} + +func (i *ImageInference) inferImages(cs []*FilteredCluster) ([]*types.ImageResult, error) { + var wg sync.WaitGroup + var ch = make(chan *types.ImageResult, len(i.files)) + var results []*types.ImageResult + limit := make(chan bool, 7) + + var imageNumIdx int32 = 0 + var imageNumIdxEnd int32 = 0 + for _, c := range cs { + new_images := make([]*ImageFile, len(i.files)) + copy(new_images, i.files) + + imageNumIdxEnd = imageNumIdxEnd + c.imageNum + new_images = new_images[imageNumIdx:imageNumIdxEnd] + imageNumIdx = imageNumIdx + c.imageNum + + wg.Add(len(new_images)) + go sendInferReq(new_images, c, &wg, ch, limit) + } + wg.Wait() + close(ch) + + for s := range ch { + results = append(results, s) + } + + sort.Slice(results, func(p, q int) bool { + return results[p].ClusterName < results[q].ClusterName + }) + + return results, nil +} + +func (i *ImageInference) updateStatus(aiTaskList []*models.TaskAi, cs []*FilteredCluster) error { + + //no cluster available + if len(cs) == 0 { + for _, t := range aiTaskList { + t.Status = constants.Failed + t.EndTime = time.Now().Format(time.RFC3339) + if _, ok := i.errMap[strconv.Itoa(int(t.ClusterId))]; ok { + t.Msg = i.errMap[strconv.Itoa(int(t.ClusterId))] + } + err := i.storage.UpdateAiTask(t) + if err != nil { + logx.Errorf(err.Error()) + } + } + i.storage.AddNoticeInfo(i.opt.AdapterId, 
i.adapterName, "", "", i.opt.TaskName, "failed", "任务失败") + return errors.New("available clusters' empty, image infer task failed") + } + + //change cluster status + if len(i.clusters) != len(cs) { + var acs []*strategy.AssignedCluster + var rcs []*strategy.AssignedCluster + for _, cluster := range i.clusters { + if contains(cs, cluster.ClusterId) { + var ac *strategy.AssignedCluster + ac = cluster + rcs = append(rcs, ac) + } else { + var ac *strategy.AssignedCluster + ac = cluster + acs = append(acs, ac) + } + } + + // update failed cluster status + for _, ac := range acs { + for _, t := range aiTaskList { + if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) { + t.Status = constants.Failed + t.EndTime = time.Now().Format(time.RFC3339) + if _, ok := i.errMap[strconv.Itoa(int(t.ClusterId))]; ok { + t.Msg = i.errMap[strconv.Itoa(int(t.ClusterId))] + } + err := i.storage.UpdateAiTask(t) + if err != nil { + logx.Errorf(err.Error()) + } + } + } + } + + // update running cluster status + for _, ac := range rcs { + for _, t := range aiTaskList { + if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) { + t.Status = constants.Running + err := i.storage.UpdateAiTask(t) + if err != nil { + logx.Errorf(err.Error()) + } + } + } + } + i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "failed", "任务失败") + } else { + for _, t := range aiTaskList { + t.Status = constants.Running + err := i.storage.UpdateAiTask(t) + if err != nil { + logx.Errorf(err.Error()) + } + } + i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "running", "任务运行中") + } + return nil +} + +func sendInferReq(images []*ImageFile, cluster *FilteredCluster, wg *sync.WaitGroup, ch chan<- *types.ImageResult, limit chan bool) { + for _, image := range images { + limit <- true + go func(t *ImageFile, c *FilteredCluster) { + if len(c.urls) == 1 { + r, err := getInferResult(c.urls[0].Url, t.File, t.ImageResult.ImageName, c.clusterName) + if err != nil { + t.ImageResult.ImageResult = err.Error() + t.ImageResult.ClusterId = c.clusterId + t.ImageResult.ClusterName = c.clusterName + t.ImageResult.Card = c.urls[0].Card + ch <- t.ImageResult + wg.Done() + <-limit + return + } + t.ImageResult.ImageResult = r + t.ImageResult.ClusterId = c.clusterId + t.ImageResult.ClusterName = c.clusterName + t.ImageResult.Card = c.urls[0].Card + + ch <- t.ImageResult + wg.Done() + <-limit + return + } else { + idx := rand.Intn(len(c.urls)) + r, err := getInferResult(c.urls[idx].Url, t.File, t.ImageResult.ImageName, c.clusterName) + if err != nil { + t.ImageResult.ImageResult = err.Error() + t.ImageResult.ClusterId = c.clusterId + t.ImageResult.ClusterName = c.clusterName + t.ImageResult.Card = c.urls[idx].Card + ch <- t.ImageResult + wg.Done() + <-limit + return + } + t.ImageResult.ImageResult = r + t.ImageResult.ClusterId = c.clusterId + t.ImageResult.ClusterName = c.clusterName + t.ImageResult.Card = c.urls[idx].Card + + ch <- t.ImageResult + wg.Done() + <-limit + return + } + }(image, cluster) + <-limit + } +} + +func (i *ImageInference) saveAiSubTasks(id int64, aiTaskList []*models.TaskAi, cs []*FilteredCluster, results []*types.ImageResult) error { + //save ai sub tasks + for _, r := range results { + for _, task := range aiTaskList { + if r.ClusterId == strconv.Itoa(int(task.ClusterId)) { + taskAiSub := models.TaskAiSub{ + TaskId: id, + TaskName: task.Name, + TaskAiId: task.TaskId, + TaskAiName: task.Name, + ImageName: r.ImageName, + Result: r.ImageResult, + Card: r.Card, + ClusterId: task.ClusterId, + 
ClusterName: r.ClusterName, + } + err := i.storage.SaveAiTaskImageSubTask(&taskAiSub) + if err != nil { + panic(err) + } + } + } + } + + // update succeeded cluster status + var successStatusCount int + for _, c := range cs { + for _, t := range aiTaskList { + if c.clusterId == strconv.Itoa(int(t.ClusterId)) { + t.Status = constants.Completed + t.EndTime = time.Now().Format(time.RFC3339) + err := i.storage.UpdateAiTask(t) + if err != nil { + logx.Errorf(err.Error()) + } + successStatusCount++ + } else { + continue + } + } + } + + if len(cs) == successStatusCount { + i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "completed", "任务完成") + } else { + i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "failed", "任务失败") + } + + return nil +} + +func getInferResult(url string, file multipart.File, fileName string, clusterName string) (string, error) { + if clusterName == "鹏城云脑II-modelarts" { + r, err := getInferResultModelarts(url, file, fileName) + if err != nil { + return "", err + } + return r, nil + } + var res Res + req := GetRestyRequest(20) + _, err := req. + SetFileReader("file", fileName, file). + SetResult(&res). + Post(url) + if err != nil { + return "", err + } + return res.Result, nil +} + +func getInferResultModelarts(url string, file multipart.File, fileName string) (string, error) { + var res Res + /* req := GetRestyRequest(20) + _, err := req. + SetFileReader("file", fileName, file). + SetHeaders(map[string]string{ + "ak": "UNEHPHO4Z7YSNPKRXFE4", + "sk": "JWXCE9qcYbc7RjpSRIWt4WgG3ZKF6Q4lPzkJReX9", + }). + SetResult(&res). + Post(url) + if err != nil { + return "", err + }*/ + body, err := utils.SendRequest("POST", url, file, fileName) + if err != nil { + return "", err + } + errjson := json.Unmarshal([]byte(body), &res) + if errjson != nil { + log.Fatalf("Error parsing JSON: %s", errjson) + } + return res.Result, nil +} + +func GetRestyRequest(timeoutSeconds int64) *resty.Request { + client := resty.New().SetTimeout(time.Duration(timeoutSeconds) * time.Second) + request := client.R() + return request +} + +type Res struct { + Result string `json:"result"` +} + +func contains(cs []*FilteredCluster, e string) bool { + for _, c := range cs { + if c.clusterId == e { + return true + } + } + return false } diff --git a/internal/scheduler/service/inference/imageInference/imageToText.go b/internal/scheduler/service/inference/imageInference/imageToText.go index 66c68e99..84d10643 100644 --- a/internal/scheduler/service/inference/imageInference/imageToText.go +++ b/internal/scheduler/service/inference/imageInference/imageToText.go @@ -1,19 +1,22 @@ package imageInference -import ( - "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/database" - "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option" - "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference" - "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/strategy" +import "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference" + +const ( + IMAGETOTEXT = "image-to-text" + IMAGETOTEXT_AiTYPE = "13" ) type ImageToText struct { - files []*ImageFile - clusters []*strategy.AssignedCluster - opt *option.InferOption - storage *database.AiStorage - inferAdapter map[string]map[string]inference.Inference - errMap map[string]string - taskId int64 - adapterName string +} + +func (it *ImageToText) AppendRoute(urls []*inference.InferUrl) error { + for i, _ := range urls { + urls[i].Url = 
urls[i].Url + inference.FORWARD_SLASH + IMAGETOTEXT + } + return nil +} + +func (it *ImageToText) GetAiType() string { + return IMAGETOTEXT_AiTYPE } diff --git a/internal/scheduler/service/inference/inference.go b/internal/scheduler/service/inference/inference.go index c962778c..88539a31 100644 --- a/internal/scheduler/service/inference/inference.go +++ b/internal/scheduler/service/inference/inference.go @@ -5,377 +5,25 @@ import ( "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option" ) -type Inference interface { +const ( + FORWARD_SLASH = "/" +) + +type ICluster interface { GetInferUrl(ctx context.Context, option *option.InferOption) ([]*InferUrl, error) //GetInferDeployInstanceList(ctx context.Context, option *option.InferOption) } +type IInference interface { + CreateTask() (int64, error) + InferTask(id int64) error +} + +type Inference struct { + In IInference +} + type InferUrl struct { Url string Card string } - -//func ImageInfer(opt *option.InferOption, id int64, adapterName string, clusters []*strategy.AssignedCluster, ts []*ImageFile, inferAdapterMap map[string]map[string]Inference, storage *database.AiStorage, ctx context.Context) ([]*types.ImageResult, error) { -// -// //for i := len(clusters) - 1; i >= 0; i-- { -// // if clusters[i].Replicas == 0 { -// // clusters = append(clusters[:i], clusters[i+1:]...) -// // } -// //} -// var wg sync.WaitGroup -// var cluster_ch = make(chan struct { -// urls []*InferUrl -// clusterId string -// clusterName string -// imageNum int32 -// }, len(clusters)) -// -// var cs []struct { -// urls []*InferUrl -// clusterId string -// clusterName string -// imageNum int32 -// } -// inferMap := inferAdapterMap[opt.AdapterId] -// -// ////save taskai -// //for _, c := range clusters { -// // clusterName, _ := storage.GetClusterNameById(c.ClusterId) -// // opt.Replica = c.Replicas -// // err := storage.SaveAiTask(id, opt, adapterName, c.ClusterId, clusterName, "", constants.Saved, "") -// // if err != nil { -// // return nil, err -// // } -// //} -// -// var mutex sync.Mutex -// errMap := make(map[string]string) -// for _, cluster := range clusters { -// wg.Add(1) -// c := cluster -// go func() { -// imageUrls, err := inferMap[c.ClusterId].GetInferUrl(ctx, opt) -// if err != nil { -// mutex.Lock() -// errMap[c.ClusterId] = err.Error() -// mutex.Unlock() -// wg.Done() -// return -// } -// for i, _ := range imageUrls { -// imageUrls[i].Url = imageUrls[i].Url + "/" + "image" -// } -// clusterName, _ := storage.GetClusterNameById(c.ClusterId) -// -// s := struct { -// urls []*InferUrl -// clusterId string -// clusterName string -// imageNum int32 -// }{ -// urls: imageUrls, -// clusterId: c.ClusterId, -// clusterName: clusterName, -// imageNum: c.Replicas, -// } -// -// cluster_ch <- s -// wg.Done() -// return -// }() -// } -// wg.Wait() -// close(cluster_ch) -// -// for s := range cluster_ch { -// cs = append(cs, s) -// } -// -// aiTaskList, err := storage.GetAiTaskListById(id) -// if err != nil { -// return nil, err -// } -// -// //no cluster available -// if len(cs) == 0 { -// for _, t := range aiTaskList { -// t.Status = constants.Failed -// t.EndTime = time.Now().Format(time.RFC3339) -// if _, ok := errMap[strconv.Itoa(int(t.ClusterId))]; ok { -// t.Msg = errMap[strconv.Itoa(int(t.ClusterId))] -// } -// err := storage.UpdateAiTask(t) -// if err != nil { -// logx.Errorf(err.Error()) -// } -// } -// storage.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败") -// return nil, errors.New("image infer 
task failed") -// } -// -// //change cluster status -// if len(clusters) != len(cs) { -// var acs []*strategy.AssignedCluster -// var rcs []*strategy.AssignedCluster -// for _, cluster := range clusters { -// if contains(cs, cluster.ClusterId) { -// var ac *strategy.AssignedCluster -// ac = cluster -// rcs = append(rcs, ac) -// } else { -// var ac *strategy.AssignedCluster -// ac = cluster -// acs = append(acs, ac) -// } -// } -// -// // update failed cluster status -// for _, ac := range acs { -// for _, t := range aiTaskList { -// if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) { -// t.Status = constants.Failed -// t.EndTime = time.Now().Format(time.RFC3339) -// if _, ok := errMap[strconv.Itoa(int(t.ClusterId))]; ok { -// t.Msg = errMap[strconv.Itoa(int(t.ClusterId))] -// } -// err := storage.UpdateAiTask(t) -// if err != nil { -// logx.Errorf(err.Error()) -// } -// } -// } -// } -// -// // update running cluster status -// for _, ac := range rcs { -// for _, t := range aiTaskList { -// if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) { -// t.Status = constants.Running -// err := storage.UpdateAiTask(t) -// if err != nil { -// logx.Errorf(err.Error()) -// } -// } -// } -// } -// storage.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败") -// } else { -// for _, t := range aiTaskList { -// t.Status = constants.Running -// err := storage.UpdateAiTask(t) -// if err != nil { -// logx.Errorf(err.Error()) -// } -// } -// storage.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "running", "任务运行中") -// } -// -// var result_ch = make(chan *types.ImageResult, len(ts)) -// var results []*types.ImageResult -// limit := make(chan bool, 7) -// -// var imageNumIdx int32 = 0 -// var imageNumIdxEnd int32 = 0 -// for _, c := range cs { -// new_images := make([]*ImageFile, len(ts)) -// copy(new_images, ts) -// -// imageNumIdxEnd = imageNumIdxEnd + c.imageNum -// new_images = new_images[imageNumIdx:imageNumIdxEnd] -// imageNumIdx = imageNumIdx + c.imageNum -// -// wg.Add(len(new_images)) -// go sendInferReq(new_images, c, &wg, result_ch, limit) -// } -// wg.Wait() -// close(result_ch) -// -// for s := range result_ch { -// results = append(results, s) -// } -// -// sort.Slice(results, func(p, q int) bool { -// return results[p].ClusterName < results[q].ClusterName -// }) -// -// //save ai sub tasks -// for _, r := range results { -// for _, task := range aiTaskList { -// if r.ClusterId == strconv.Itoa(int(task.ClusterId)) { -// taskAiSub := models.TaskAiSub{ -// TaskId: id, -// TaskName: task.Name, -// TaskAiId: task.TaskId, -// TaskAiName: task.Name, -// ImageName: r.ImageName, -// Result: r.ImageResult, -// Card: r.Card, -// ClusterId: task.ClusterId, -// ClusterName: r.ClusterName, -// } -// err := storage.SaveAiTaskImageSubTask(&taskAiSub) -// if err != nil { -// panic(err) -// } -// } -// } -// } -// -// // update succeeded cluster status -// var successStatusCount int -// for _, c := range cs { -// for _, t := range aiTaskList { -// if c.clusterId == strconv.Itoa(int(t.ClusterId)) { -// t.Status = constants.Completed -// t.EndTime = time.Now().Format(time.RFC3339) -// err := storage.UpdateAiTask(t) -// if err != nil { -// logx.Errorf(err.Error()) -// } -// successStatusCount++ -// } else { -// continue -// } -// } -// } -// -// if len(cs) == successStatusCount { -// storage.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "completed", "任务完成") -// } else { -// storage.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", 
"任务失败") -// } -// -// return results, nil -//} -// -//func sendInferReq(images []*ImageFile, cluster struct { -// urls []*InferUrl -// clusterId string -// clusterName string -// imageNum int32 -//}, wg *sync.WaitGroup, ch chan<- *types.ImageResult, limit chan bool) { -// for _, image := range images { -// limit <- true -// go func(t *ImageFile, c struct { -// urls []*InferUrl -// clusterId string -// clusterName string -// imageNum int32 -// }) { -// if len(c.urls) == 1 { -// r, err := getInferResult(c.urls[0].Url, t.File, t.ImageResult.ImageName, c.clusterName) -// if err != nil { -// t.ImageResult.ImageResult = err.Error() -// t.ImageResult.ClusterId = c.clusterId -// t.ImageResult.ClusterName = c.clusterName -// t.ImageResult.Card = c.urls[0].Card -// ch <- t.ImageResult -// wg.Done() -// <-limit -// return -// } -// t.ImageResult.ImageResult = r -// t.ImageResult.ClusterId = c.clusterId -// t.ImageResult.ClusterName = c.clusterName -// t.ImageResult.Card = c.urls[0].Card -// -// ch <- t.ImageResult -// wg.Done() -// <-limit -// return -// } else { -// idx := rand.Intn(len(c.urls)) -// r, err := getInferResult(c.urls[idx].Url, t.File, t.ImageResult.ImageName, c.clusterName) -// if err != nil { -// t.ImageResult.ImageResult = err.Error() -// t.ImageResult.ClusterId = c.clusterId -// t.ImageResult.ClusterName = c.clusterName -// t.ImageResult.Card = c.urls[idx].Card -// ch <- t.ImageResult -// wg.Done() -// <-limit -// return -// } -// t.ImageResult.ImageResult = r -// t.ImageResult.ClusterId = c.clusterId -// t.ImageResult.ClusterName = c.clusterName -// t.ImageResult.Card = c.urls[idx].Card -// -// ch <- t.ImageResult -// wg.Done() -// <-limit -// return -// } -// }(image, cluster) -// <-limit -// } -//} -// -//func getInferResult(url string, file multipart.File, fileName string, clusterName string) (string, error) { -// if clusterName == "鹏城云脑II-modelarts" { -// r, err := getInferResultModelarts(url, file, fileName) -// if err != nil { -// return "", err -// } -// return r, nil -// } -// var res Res -// req := GetRestyRequest(20) -// _, err := req. -// SetFileReader("file", fileName, file). -// SetResult(&res). -// Post(url) -// if err != nil { -// return "", err -// } -// return res.Result, nil -//} -// -//func getInferResultModelarts(url string, file multipart.File, fileName string) (string, error) { -// var res Res -// /* req := GetRestyRequest(20) -// _, err := req. -// SetFileReader("file", fileName, file). -// SetHeaders(map[string]string{ -// "ak": "UNEHPHO4Z7YSNPKRXFE4", -// "sk": "JWXCE9qcYbc7RjpSRIWt4WgG3ZKF6Q4lPzkJReX9", -// }). -// SetResult(&res). 
-// Post(url) -// if err != nil { -// return "", err -// }*/ -// body, err := utils.SendRequest("POST", url, file, fileName) -// if err != nil { -// return "", err -// } -// errjson := json.Unmarshal([]byte(body), &res) -// if errjson != nil { -// log.Fatalf("Error parsing JSON: %s", errjson) -// } -// return res.Result, nil -//} -// -//func GetRestyRequest(timeoutSeconds int64) *resty.Request { -// client := resty.New().SetTimeout(time.Duration(timeoutSeconds) * time.Second) -// request := client.R() -// return request -//} -// -//type Res struct { -// Result string `json:"result"` -//} -// -//func contains(cs []struct { -// urls []*InferUrl -// clusterId string -// clusterName string -// imageNum int32 -//}, e string) bool { -// for _, c := range cs { -// if c.clusterId == e { -// return true -// } -// } -// return false -//} diff --git a/internal/scheduler/service/inference/textInference/textInference.go b/internal/scheduler/service/inference/textInference/textInference.go index 29d62ad7..42afa5d7 100644 --- a/internal/scheduler/service/inference/textInference/textInference.go +++ b/internal/scheduler/service/inference/textInference/textInference.go @@ -1 +1,97 @@ package textInference + +import ( + "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/database" + "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option" + "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference" + "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models" +) + +type ITextInference interface { + SaveAiTask(id int64, adapterName string) error + UpdateStatus(aiTaskList []*models.TaskAi, adapterName string) error + AppendRoute(urls []*inference.InferUrl) error + AiType() string +} + +type FilteredCluster struct { + urls []*inference.InferUrl + clusterId string + clusterName string +} + +type TextInference struct { + inference ITextInference + opt *option.InferOption + storage *database.AiStorage + inferAdapter map[string]map[string]inference.ICluster + errMap map[string]string + adapterName string +} + +func New( + inference ITextInference, + opt *option.InferOption, + storage *database.AiStorage, + inferAdapter map[string]map[string]inference.ICluster, + adapterName string) (*TextInference, error) { + return &TextInference{ + inference: inference, + opt: opt, + storage: storage, + inferAdapter: inferAdapter, + adapterName: adapterName, + errMap: make(map[string]string), + }, nil +} + +func (ti *TextInference) CreateTask() (int64, error) { + id, err := ti.saveTask() + if err != nil { + return 0, err + } + err = ti.saveAiTask(id) + if err != nil { + return 0, err + } + return id, nil +} + +func (ti *TextInference) InferTask(id int64) error { + aiTaskList, err := ti.storage.GetAiTaskListById(id) + if err != nil || len(aiTaskList) == 0 { + return err + } + err = ti.updateStatus(aiTaskList) + if err != nil { + return err + } + return nil +} + +func (ti *TextInference) saveTask() (int64, error) { + var synergystatus int64 + var strategyCode int64 + + id, err := ti.storage.SaveTask(ti.opt.TaskName, strategyCode, synergystatus, ti.inference.AiType()) + if err != nil { + return 0, err + } + return id, nil +} + +func (ti *TextInference) saveAiTask(id int64) error { + err := ti.inference.SaveAiTask(id, ti.adapterName) + if err != nil { + return err + } + return nil +} + +func (ti *TextInference) updateStatus(aiTaskList []*models.TaskAi) error { + err := ti.inference.UpdateStatus(aiTaskList, ti.adapterName) + if err != nil { + return err + } + return nil +} diff --git 
a/internal/scheduler/service/inference/textInference/textToImage.go b/internal/scheduler/service/inference/textInference/textToImage.go new file mode 100644 index 00000000..fc1e41c4 --- /dev/null +++ b/internal/scheduler/service/inference/textInference/textToImage.go @@ -0,0 +1,48 @@ +package textInference + +import ( + "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/database" + "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option" + "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference" + "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/strategy" + "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants" + "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models" +) + +const ( + TEXTTOIMAGE = "text-to-image" + TEXTTOIMAGE_AiTYPE = "14" +) + +type TextToImage struct { + clusters []*strategy.AssignedCluster + storage *database.AiStorage + opt *option.InferOption +} + +func (t *TextToImage) SaveAiTask(id int64, adapterName string) error { + for _, c := range t.clusters { + clusterName, _ := t.storage.GetClusterNameById(c.ClusterId) + t.opt.Replica = c.Replicas + err := t.storage.SaveAiTask(id, t.opt, adapterName, c.ClusterId, clusterName, "", constants.Saved, "") + if err != nil { + return err + } + } + return nil +} + +func (t *TextToImage) UpdateStatus(aiTaskList []*models.TaskAi, adapterName string) error { + return nil +} + +func (t *TextToImage) AppendRoute(urls []*inference.InferUrl) error { + for i, _ := range urls { + urls[i].Url = urls[i].Url + inference.FORWARD_SLASH + TEXTTOIMAGE + } + return nil +} + +func (t *TextToImage) AiType() string { + return TEXTTOIMAGE_AiTYPE +} diff --git a/internal/scheduler/service/inference/textInference/textToText.go b/internal/scheduler/service/inference/textInference/textToText.go new file mode 100644 index 00000000..0319bda7 --- /dev/null +++ b/internal/scheduler/service/inference/textInference/textToText.go @@ -0,0 +1,131 @@ +package textInference + +import ( + "github.com/zeromicro/go-zero/core/logx" + "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/database" + "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option" + "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference" + "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants" + "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models" + "net/http" + "strconv" + "sync" + "time" +) + +const ( + CHAT = "chat" + TEXTTOTEXT_AITYPE = "12" +) + +type TextToText struct { + opt *option.InferOption + storage *database.AiStorage + inferAdapter map[string]map[string]inference.ICluster + cs []*FilteredCluster +} + +func NewTextToText(opt *option.InferOption, storage *database.AiStorage, inferAdapter map[string]map[string]inference.ICluster) (*TextToText, error) { + cs, err := filterClusters(opt, storage, inferAdapter) + if err != nil { + return nil, err + } + return &TextToText{ + opt: opt, + storage: storage, + inferAdapter: inferAdapter, + cs: cs, + }, nil +} + +func (tt *TextToText) AppendRoute(urls []*inference.InferUrl) error { + for i, _ := range urls { + urls[i].Url = urls[i].Url + inference.FORWARD_SLASH + CHAT + } + return nil +} + +func (tt *TextToText) AiType() string { + return TEXTTOTEXT_AITYPE +} + +func (tt *TextToText) SaveAiTask(id int64, adapterName string) error { + + if len(tt.cs) == 0 { + clusterId := tt.opt.AiClusterIds[0] + clusterName, _ := tt.storage.GetClusterNameById(tt.opt.AiClusterIds[0]) + err := tt.storage.SaveAiTask(id, tt.opt, 
adapterName, clusterId, clusterName, "", constants.Failed, "") + if err != nil { + return err + } + tt.storage.AddNoticeInfo(tt.opt.AdapterId, adapterName, "", "", tt.opt.TaskName, "failed", "任务失败") + } + + for _, c := range tt.cs { + clusterName, _ := tt.storage.GetClusterNameById(c.clusterId) + err := tt.storage.SaveAiTask(id, tt.opt, adapterName, c.clusterId, clusterName, "", constants.Saved, "") + if err != nil { + return err + } + } + return nil +} + +func filterClusters(opt *option.InferOption, storage *database.AiStorage, inferAdapter map[string]map[string]inference.ICluster) ([]*FilteredCluster, error) { + var wg sync.WaitGroup + var ch = make(chan *FilteredCluster, len(opt.AiClusterIds)) + var cs []*FilteredCluster + inferMap := inferAdapter[opt.AdapterId] + + for _, clusterId := range opt.AiClusterIds { + wg.Add(1) + go func(cId string) { + r := http.Request{} + urls, err := inferMap[cId].GetInferUrl(r.Context(), opt) + if err != nil { + wg.Done() + return + } + for i, _ := range urls { + urls[i].Url = urls[i].Url + inference.FORWARD_SLASH + CHAT + } + clusterName, _ := storage.GetClusterNameById(cId) + + var f FilteredCluster + f.urls = urls + f.clusterId = cId + f.clusterName = clusterName + + ch <- &f + wg.Done() + return + }(clusterId) + } + wg.Wait() + close(ch) + + for s := range ch { + cs = append(cs, s) + } + + return cs, nil +} + +func (tt *TextToText) UpdateStatus(aiTaskList []*models.TaskAi, adapterName string) error { + for i, t := range aiTaskList { + if strconv.Itoa(int(t.ClusterId)) == tt.cs[i].clusterId { + t.Status = constants.Completed + t.EndTime = time.Now().Format(time.RFC3339) + url := tt.cs[i].urls[0].Url + t.InferUrl = url + err := tt.storage.UpdateAiTask(t) + if err != nil { + logx.Errorf(err.Error()) + return err + } + } + } + + tt.storage.AddNoticeInfo(tt.opt.AdapterId, adapterName, "", "", tt.opt.TaskName, "completed", "任务完成") + return nil +}
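
For reference, the pattern this patch introduces is that each modality only supplies `AppendRoute` and `GetAiType` (image side) or the `ITextInference` methods (text side), while the shared `CreateTask`/`InferTask` flow lives in `ImageInference`/`TextInference` behind the `IInference` interface. A minimal sketch of how a further image modality could plug in, assuming a hypothetical `object-detection` route and AI-type code (neither is defined in this patch; only CLASSIFICATION "image"/"11" and IMAGETOTEXT "image-to-text"/"13" are):

```go
package imageInference

import "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference"

// Hypothetical values for illustration only; not part of this patch.
const (
	OBJECTDETECTION        = "object-detection"
	OBJECTDETECTION_AiTYPE = "15"
)

type ObjectDetection struct{}

// AppendRoute appends the modality-specific path segment to each cluster's
// inference URL, mirroring ImageClassification.AppendRoute and ImageToText.AppendRoute.
func (od *ObjectDetection) AppendRoute(urls []*inference.InferUrl) error {
	for i := range urls {
		urls[i].Url = urls[i].Url + inference.FORWARD_SLASH + OBJECTDETECTION
	}
	return nil
}

// GetAiType returns the task-type code recorded when the task is saved.
func (od *ObjectDetection) GetAiType() string {
	return OBJECTDETECTION_AiTYPE
}
```

The handler wiring would then be identical to the classification path in `imageinferencelogic.go` above: construct it with `imageInference.New(...)`, wrap it in `inference.Inference{In: ...}`, call `CreateTask`, and run `InferTask(id)` in a goroutine.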