From 1ee960f52ada14d4357d7c2017ff8c3d475c6bbf Mon Sep 17 00:00:00 2001 From: tzwang Date: Tue, 25 Jun 2024 17:06:40 +0800 Subject: [PATCH 1/3] added textinference api Former-commit-id: d32678e386f68034fa9f932975c2d09c6c825e2c --- api/desc/inference/inference.api | 12 +++++++- api/desc/pcm.api | 3 ++ .../inference/texttotextinferencehandler.go | 28 +++++++++++++++++ api/internal/handler/routes.go | 5 ++++ .../inference/texttotextinferencelogic.go | 30 +++++++++++++++++++ api/internal/types/types.go | 12 ++++++++ 6 files changed, 89 insertions(+), 1 deletion(-) create mode 100644 api/internal/handler/inference/texttotextinferencehandler.go create mode 100644 api/internal/logic/inference/texttotextinferencelogic.go diff --git a/api/desc/inference/inference.api b/api/desc/inference/inference.api index ba9b1946..1872d2c6 100644 --- a/api/desc/inference/inference.api +++ b/api/desc/inference/inference.api @@ -63,7 +63,17 @@ type ( clusterName string `json:"clusterName"` } + /******************TextToText inference*************************/ + TextToTextInferenceReq{ + TaskName string `form:"taskName"` + TaskDesc string `form:"taskDesc"` + ModelName string `form:"modelName"` + ModelType string `form:"modelType"` + AdapterId string `form:"adapterId"` + ClusterId string `form:"clusterIds"` + } + TextToTextInferenceResp{ - + } ) diff --git a/api/desc/pcm.api b/api/desc/pcm.api index b912fda9..0a7b51bf 100644 --- a/api/desc/pcm.api +++ b/api/desc/pcm.api @@ -907,6 +907,9 @@ service pcm { group: inference ) service pcm { + @handler TextToTextInferenceHandler + post /inference/text (TextToTextInferenceReq) returns (TextToTextInferenceResp) + @handler ImageInferenceHandler post /inference/images (ImageInferenceReq) returns (ImageInferenceResp) diff --git a/api/internal/handler/inference/texttotextinferencehandler.go b/api/internal/handler/inference/texttotextinferencehandler.go new file mode 100644 index 00000000..55df289c --- /dev/null +++ b/api/internal/handler/inference/texttotextinferencehandler.go @@ -0,0 +1,28 @@ +package inference + +import ( + "net/http" + + "github.com/zeromicro/go-zero/rest/httpx" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/logic/inference" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types" +) + +func TextToTextInferenceHandler(svcCtx *svc.ServiceContext) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + var req types.TextToTextInferenceReq + if err := httpx.Parse(r, &req); err != nil { + httpx.ErrorCtx(r.Context(), w, err) + return + } + + l := inference.NewTextToTextInferenceLogic(r.Context(), svcCtx) + resp, err := l.TextToTextInference(&req) + if err != nil { + httpx.ErrorCtx(r.Context(), w, err) + } else { + httpx.OkJsonCtx(r.Context(), w, resp) + } + } +} diff --git a/api/internal/handler/routes.go b/api/internal/handler/routes.go index e04c7750..2daff30f 100644 --- a/api/internal/handler/routes.go +++ b/api/internal/handler/routes.go @@ -1138,6 +1138,11 @@ func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) { server.AddRoutes( []rest.Route{ + { + Method: http.MethodPost, + Path: "/inference/text", + Handler: inference.TextToTextInferenceHandler(serverCtx), + }, { Method: http.MethodPost, Path: "/inference/images", diff --git a/api/internal/logic/inference/texttotextinferencelogic.go b/api/internal/logic/inference/texttotextinferencelogic.go new file mode 100644 index 00000000..faf61301 --- /dev/null +++ b/api/internal/logic/inference/texttotextinferencelogic.go @@ -0,0 +1,30 @@ +package inference + +import ( + "context" + + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types" + + "github.com/zeromicro/go-zero/core/logx" +) + +type TextToTextInferenceLogic struct { + logx.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +func NewTextToTextInferenceLogic(ctx context.Context, svcCtx *svc.ServiceContext) *TextToTextInferenceLogic { + return &TextToTextInferenceLogic{ + Logger: logx.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *TextToTextInferenceLogic) TextToTextInference(req *types.TextToTextInferenceReq) (resp *types.TextToTextInferenceResp, err error) { + // todo: add your logic here and delete this line + + return +} diff --git a/api/internal/types/types.go b/api/internal/types/types.go index 86f165bd..b2a942da 100644 --- a/api/internal/types/types.go +++ b/api/internal/types/types.go @@ -5941,3 +5941,15 @@ type InferenceResult struct { Card string `json:"card"` ClusterName string `json:"clusterName"` } + +type TextToTextInferenceReq struct { + TaskName string `form:"taskName"` + TaskDesc string `form:"taskDesc"` + ModelName string `form:"modelName"` + ModelType string `form:"modelType"` + AdapterId string `form:"adapterId"` + ClusterId string `form:"clusterIds"` +} + +type TextToTextInferenceResp struct { +} From c9084ff0936ec4d08c59e6da7102efa13e4c4294 Mon Sep 17 00:00:00 2001 From: tzwang Date: Tue, 25 Jun 2024 17:12:37 +0800 Subject: [PATCH 2/3] added textinference api Former-commit-id: ac9509e26c6a33b5caf86b8a2e6ff19b1e8ae0f7 --- api/desc/inference/inference.api | 2 +- api/internal/types/types.go | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/api/desc/inference/inference.api b/api/desc/inference/inference.api index 1872d2c6..f2e88943 100644 --- a/api/desc/inference/inference.api +++ b/api/desc/inference/inference.api @@ -70,7 +70,7 @@ type ( ModelName string `form:"modelName"` ModelType string `form:"modelType"` AdapterId string `form:"adapterId"` - ClusterId string `form:"clusterIds"` + AiClusterIds []string `form:"aiClusterIds"` } TextToTextInferenceResp{ diff --git a/api/internal/types/types.go b/api/internal/types/types.go index b2a942da..7ab16536 100644 --- a/api/internal/types/types.go +++ b/api/internal/types/types.go @@ -5943,12 +5943,12 @@ type InferenceResult struct { } type TextToTextInferenceReq struct { - TaskName string `form:"taskName"` - TaskDesc string `form:"taskDesc"` - ModelName string `form:"modelName"` - ModelType string `form:"modelType"` - AdapterId string `form:"adapterId"` - ClusterId string `form:"clusterIds"` + TaskName string `form:"taskName"` + TaskDesc string `form:"taskDesc"` + ModelName string `form:"modelName"` + ModelType string `form:"modelType"` + AdapterId string `form:"adapterId"` + AiClusterIds []string `form:"aiClusterIds"` } type TextToTextInferenceResp struct { From e6b9d3d23b521a10623426436f0db40f2ee4457a Mon Sep 17 00:00:00 2001 From: tzwang Date: Tue, 25 Jun 2024 18:19:18 +0800 Subject: [PATCH 3/3] added textinfer api Former-commit-id: bfdce9025139d71bc3178b039756d579be1b450c --- .../inference/texttotextinferencehandler.go | 13 +- .../logic/inference/imageinferencelogic.go | 14 +- .../inference/texttotextinferencelogic.go | 120 +++++++++++++++++- .../scheduler/service/collector/collector.go | 4 +- api/internal/storeLink/modelarts.go | 6 +- api/internal/storeLink/octopus.go | 6 +- api/internal/storeLink/shuguangai.go | 6 +- api/internal/storeLink/storeLink.go | 1 + pkg/models/taskaimodel_gen.go | 1 + 9 files changed, 140 insertions(+), 31 deletions(-) diff --git a/api/internal/handler/inference/texttotextinferencehandler.go b/api/internal/handler/inference/texttotextinferencehandler.go index 55df289c..d291d60e 100644 --- a/api/internal/handler/inference/texttotextinferencehandler.go +++ b/api/internal/handler/inference/texttotextinferencehandler.go @@ -1,28 +1,25 @@ package inference import ( - "net/http" - "github.com/zeromicro/go-zero/rest/httpx" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/logic/inference" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types" + "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/repository/result" + "net/http" ) func TextToTextInferenceHandler(svcCtx *svc.ServiceContext) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var req types.TextToTextInferenceReq if err := httpx.Parse(r, &req); err != nil { - httpx.ErrorCtx(r.Context(), w, err) + result.ParamErrorResult(r, w, err) return } l := inference.NewTextToTextInferenceLogic(r.Context(), svcCtx) resp, err := l.TextToTextInference(&req) - if err != nil { - httpx.ErrorCtx(r.Context(), w, err) - } else { - httpx.OkJsonCtx(r.Context(), w, resp) - } + result.HttpResult(r, w, resp, err) + } } diff --git a/api/internal/logic/inference/imageinferencelogic.go b/api/internal/logic/inference/imageinferencelogic.go index ec40d60d..48aaf97a 100644 --- a/api/internal/logic/inference/imageinferencelogic.go +++ b/api/internal/logic/inference/imageinferencelogic.go @@ -133,14 +133,14 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s var wg sync.WaitGroup var cluster_ch = make(chan struct { - urls []*collector.ImageInferUrl + urls []*collector.InferUrl clusterId string clusterName string imageNum int32 }, len(clusters)) var cs []struct { - urls []*collector.ImageInferUrl + urls []*collector.InferUrl clusterId string clusterName string imageNum int32 @@ -182,7 +182,7 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s wg.Add(1) c := cluster go func() { - imageUrls, err := collectorMap[c.ClusterId].GetImageInferUrl(ctx, opt) + imageUrls, err := collectorMap[c.ClusterId].GetInferUrl(ctx, opt) if err != nil { wg.Done() return @@ -190,7 +190,7 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s clusterName, _ := svcCtx.Scheduler.AiStorages.GetClusterNameById(c.ClusterId) s := struct { - urls []*collector.ImageInferUrl + urls []*collector.InferUrl clusterId string clusterName string imageNum int32 @@ -373,7 +373,7 @@ func sendInferReq(images []struct { imageResult *types.ImageResult file multipart.File }, cluster struct { - urls []*collector.ImageInferUrl + urls []*collector.InferUrl clusterId string clusterName string imageNum int32 @@ -384,7 +384,7 @@ func sendInferReq(images []struct { imageResult *types.ImageResult file multipart.File }, c struct { - urls []*collector.ImageInferUrl + urls []*collector.InferUrl clusterId string clusterName string imageNum int32 @@ -494,7 +494,7 @@ type Res struct { } func contains(cs []struct { - urls []*collector.ImageInferUrl + urls []*collector.InferUrl clusterId string clusterName string imageNum int32 diff --git a/api/internal/logic/inference/texttotextinferencelogic.go b/api/internal/logic/inference/texttotextinferencelogic.go index faf61301..2974d834 100644 --- a/api/internal/logic/inference/texttotextinferencelogic.go +++ b/api/internal/logic/inference/texttotextinferencelogic.go @@ -2,11 +2,18 @@ package inference import ( "context" - + "errors" + "github.com/zeromicro/go-zero/core/logx" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/schedulers/option" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/collector" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/storeLink" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types" - - "github.com/zeromicro/go-zero/core/logx" + "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants" + "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models" + "strconv" + "sync" + "time" ) type TextToTextInferenceLogic struct { @@ -24,7 +31,110 @@ func NewTextToTextInferenceLogic(ctx context.Context, svcCtx *svc.ServiceContext } func (l *TextToTextInferenceLogic) TextToTextInference(req *types.TextToTextInferenceReq) (resp *types.TextToTextInferenceResp, err error) { - // todo: add your logic here and delete this line + resp = &types.TextToTextInferenceResp{} + opt := &option.InferOption{ + TaskName: req.TaskName, + TaskDesc: req.TaskDesc, + AdapterId: req.AdapterId, + AiClusterIds: req.AiClusterIds, + ModelName: req.ModelName, + ModelType: req.ModelType, + } - return + _, ok := l.svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId] + if !ok { + return nil, errors.New("AdapterId does not exist") + } + + //save task + var synergystatus int64 + var strategyCode int64 + adapterName, err := l.svcCtx.Scheduler.AiStorages.GetAdapterNameById(opt.AdapterId) + if err != nil { + return nil, err + } + + id, err := l.svcCtx.Scheduler.AiStorages.SaveTask(opt.TaskName, strategyCode, synergystatus, "12") + if err != nil { + return nil, err + } + + var wg sync.WaitGroup + var cluster_ch = make(chan struct { + urls []*collector.InferUrl + clusterId string + clusterName string + }, len(opt.AiClusterIds)) + + var cs []struct { + urls []*collector.InferUrl + clusterId string + clusterName string + } + collectorMap := l.svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId] + + //save taskai + for _, clusterId := range opt.AiClusterIds { + wg.Add(1) + go func(cId string) { + urls, err := collectorMap[cId].GetInferUrl(l.ctx, opt) + if err != nil { + wg.Done() + return + } + clusterName, _ := l.svcCtx.Scheduler.AiStorages.GetClusterNameById(cId) + + s := struct { + urls []*collector.InferUrl + clusterId string + clusterName string + }{ + urls: urls, + clusterId: cId, + clusterName: clusterName, + } + + cluster_ch <- s + wg.Done() + return + }(clusterId) + } + wg.Wait() + close(cluster_ch) + + for s := range cluster_ch { + cs = append(cs, s) + } + + for _, c := range cs { + clusterName, _ := l.svcCtx.Scheduler.AiStorages.GetClusterNameById(c.clusterId) + err := l.svcCtx.Scheduler.AiStorages.SaveAiTask(id, opt, adapterName, c.clusterId, clusterName, "", constants.Saved, "") + if err != nil { + return nil, err + } + } + + var aiTaskList []*models.TaskAi + tx := l.svcCtx.DbEngin.Raw("select * from task_ai where `task_id` = ? ", id).Scan(&aiTaskList) + if tx.Error != nil { + return nil, tx.Error + + } + + for i, t := range aiTaskList { + if strconv.Itoa(int(t.ClusterId)) == cs[i].clusterId { + t.Status = constants.Completed + t.EndTime = time.Now().Format(time.RFC3339) + url := cs[i].urls[0].Url + storeLink.FORWARD_SLASH + "chat" + t.InferUrl = url + err := l.svcCtx.Scheduler.AiStorages.UpdateAiTask(t) + if err != nil { + logx.Errorf(tx.Error.Error()) + } + } + } + + l.svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "completed", "任务完成") + + return resp, nil } diff --git a/api/internal/scheduler/service/collector/collector.go b/api/internal/scheduler/service/collector/collector.go index 1fa777ed..4c978628 100644 --- a/api/internal/scheduler/service/collector/collector.go +++ b/api/internal/scheduler/service/collector/collector.go @@ -15,10 +15,10 @@ type AiCollector interface { UploadAlgorithmCode(ctx context.Context, resourceType string, card string, taskType string, dataset string, algorithm string, code string) error GetComputeCards(ctx context.Context) ([]string, error) GetUserBalance(ctx context.Context) (float64, error) - GetImageInferUrl(ctx context.Context, option *option.InferOption) ([]*ImageInferUrl, error) + GetInferUrl(ctx context.Context, option *option.InferOption) ([]*InferUrl, error) } -type ImageInferUrl struct { +type InferUrl struct { Url string Card string } diff --git a/api/internal/storeLink/modelarts.go b/api/internal/storeLink/modelarts.go index a2532730..331becb1 100644 --- a/api/internal/storeLink/modelarts.go +++ b/api/internal/storeLink/modelarts.go @@ -378,8 +378,8 @@ func (m *ModelArtsLink) generateAlgorithmId(ctx context.Context, option *option. return errors.New("failed to get AlgorithmId") } -func (m *ModelArtsLink) GetImageInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.ImageInferUrl, error) { - var imageUrls []*collector.ImageInferUrl +func (m *ModelArtsLink) GetInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.InferUrl, error) { + var imageUrls []*collector.InferUrl urlReq := &modelartsclient.ImageReasoningUrlReq{ ModelName: option.ModelName, Type: option.ModelType, @@ -389,7 +389,7 @@ func (m *ModelArtsLink) GetImageInferUrl(ctx context.Context, option *option.Inf if err != nil { return nil, err } - imageUrl := &collector.ImageInferUrl{ + imageUrl := &collector.InferUrl{ Url: urlResp.Url, Card: "npu", } diff --git a/api/internal/storeLink/octopus.go b/api/internal/storeLink/octopus.go index a902cf93..9c605345 100644 --- a/api/internal/storeLink/octopus.go +++ b/api/internal/storeLink/octopus.go @@ -871,7 +871,7 @@ func setResourceIdByCard(option *option.AiOption, specs *octopus.GetResourceSpec return errors.New("set ResourceId error") } -func (o *OctopusLink) GetImageInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.ImageInferUrl, error) { +func (o *OctopusLink) GetInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.InferUrl, error) { req := &octopus.GetNotebookListReq{ Platform: o.platform, PageIndex: o.pageIndex, @@ -882,12 +882,12 @@ func (o *OctopusLink) GetImageInferUrl(ctx context.Context, option *option.Infer return nil, err } - var imageUrls []*collector.ImageInferUrl + var imageUrls []*collector.InferUrl for _, notebook := range list.Payload.GetNotebooks() { if strings.Contains(notebook.AlgorithmName, option.ModelName) && notebook.Status == "running" { url := strings.Replace(notebook.Tasks[0].Url, FORWARD_SLASH, "", -1) names := strings.Split(notebook.AlgorithmName, UNDERSCORE) - imageUrl := &collector.ImageInferUrl{ + imageUrl := &collector.InferUrl{ Url: DOMAIN + url + FORWARD_SLASH + "image", Card: names[2], } diff --git a/api/internal/storeLink/shuguangai.go b/api/internal/storeLink/shuguangai.go index c0ca6d93..2a634a4c 100644 --- a/api/internal/storeLink/shuguangai.go +++ b/api/internal/storeLink/shuguangai.go @@ -730,8 +730,8 @@ func (s *ShuguangAi) generateParams(option *option.AiOption) error { return errors.New("failed to set params") } -func (s *ShuguangAi) GetImageInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.ImageInferUrl, error) { - var imageUrls []*collector.ImageInferUrl +func (s *ShuguangAi) GetInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.InferUrl, error) { + var imageUrls []*collector.InferUrl urlReq := &hpcAC.GetInferUrlReq{ ModelName: option.ModelName, @@ -743,7 +743,7 @@ func (s *ShuguangAi) GetImageInferUrl(ctx context.Context, option *option.InferO if err != nil { return nil, err } - imageUrl := &collector.ImageInferUrl{ + imageUrl := &collector.InferUrl{ Url: urlResp.Url, Card: "dcu", } diff --git a/api/internal/storeLink/storeLink.go b/api/internal/storeLink/storeLink.go index d5ead2e6..2cd54f59 100644 --- a/api/internal/storeLink/storeLink.go +++ b/api/internal/storeLink/storeLink.go @@ -78,6 +78,7 @@ var ( } ModelTypeMap = map[string][]string{ "image_recognition": {"imagenet_resnet50"}, + "text_to_text": {"chatGLM-6B"}, } AITYPE = map[string]string{ "1": OCTOPUS, diff --git a/pkg/models/taskaimodel_gen.go b/pkg/models/taskaimodel_gen.go index 6b0e07bc..24473016 100644 --- a/pkg/models/taskaimodel_gen.go +++ b/pkg/models/taskaimodel_gen.go @@ -54,6 +54,7 @@ type ( TaskType string `db:"task_type"` DeletedAt *time.Time `db:"deleted_at"` Card string `db:"card"` + InferUrl string `db:"infer_url"` } )