diff --git a/api/desc/inference/inference.api b/api/desc/inference/inference.api index c8b6b8da..86b3ad0e 100644 --- a/api/desc/inference/inference.api +++ b/api/desc/inference/inference.api @@ -2,6 +2,18 @@ syntax = "v1" type ( /******************image inference*************************/ + ModelTypesResp { + ModelTypes []string `json:"types"` + } + + ModelNamesReq { + Type string `form:"type"` + } + + ModelNamesResp { + ModelNames []string `json:"models"` + } + /******************image inference*************************/ ImageInferenceReq { TaskName string `form:"taskName"` @@ -31,4 +43,5 @@ type ( Card string `json:"card"` ImageResult string `json:"imageResult"` } + ) diff --git a/api/desc/pcm.api b/api/desc/pcm.api index 982b3d86..85d2f768 100644 --- a/api/desc/pcm.api +++ b/api/desc/pcm.api @@ -909,6 +909,12 @@ service pcm { service pcm { @handler ImageInferenceHandler post /inference/images (ImageInferenceReq) returns (ImageInferenceResp) + + @handler ModelTypesHandler + get /inference/modelTypes returns (ModelTypesResp) + + @handler ModelNamesByTypeHandler + get /inference/modelNames (ModelNamesReq) returns (ModelNamesResp) } @server( diff --git a/api/etc/pcm.yaml b/api/etc/pcm.yaml index 4dc9164e..ef1e3250 100644 --- a/api/etc/pcm.yaml +++ b/api/etc/pcm.yaml @@ -1,6 +1,7 @@ Name: pcm.core.api Host: 0.0.0.0 Port: 8999 +MaxBytes: 524288000 Timeout: 50000 diff --git a/api/internal/cron/aiCronTask.go b/api/internal/cron/aiCronTask.go index 08cf14b2..b1ac69b3 100644 --- a/api/internal/cron/aiCronTask.go +++ b/api/internal/cron/aiCronTask.go @@ -90,7 +90,7 @@ func UpdateAiTaskStatus(svc *svc.ServiceContext, tasklist []*types.TaskModel) { var wg sync.WaitGroup for _, aitask := range aiTaskList { t := aitask - if t.Status == constants.Completed || t.Status == constants.Failed { + if t.Status == constants.Completed || t.Status == constants.Failed || t.JobId == "" { continue } wg.Add(1) diff --git a/api/internal/handler/inference/modelnamesbytypehandler.go b/api/internal/handler/inference/modelnamesbytypehandler.go new file mode 100644 index 00000000..86f1e339 --- /dev/null +++ b/api/internal/handler/inference/modelnamesbytypehandler.go @@ -0,0 +1,25 @@ +package inference + +import ( + "net/http" + + "github.com/zeromicro/go-zero/rest/httpx" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/logic/inference" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types" + "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/repository/result" +) + +func ModelNamesByTypeHandler(svcCtx *svc.ServiceContext) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + var req types.ModelNamesReq + if err := httpx.Parse(r, &req); err != nil { + result.ParamErrorResult(r, w, err) + return + } + + l := inference.NewModelNamesByTypeLogic(r.Context(), svcCtx) + resp, err := l.ModelNamesByType(&req) + result.HttpResult(r, w, resp, err) + } +} diff --git a/api/internal/handler/inference/modeltypeshandler.go b/api/internal/handler/inference/modeltypeshandler.go new file mode 100644 index 00000000..1ee33432 --- /dev/null +++ b/api/internal/handler/inference/modeltypeshandler.go @@ -0,0 +1,16 @@ +package inference + +import ( + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/logic/inference" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc" + "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/repository/result" + "net/http" +) + +func ModelTypesHandler(svcCtx *svc.ServiceContext) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + l := inference.NewModelTypesLogic(r.Context(), svcCtx) + resp, err := l.ModelTypes() + result.HttpResult(r, w, resp, err) + } +} diff --git a/api/internal/handler/routes.go b/api/internal/handler/routes.go index 94a504ca..081c4d85 100644 --- a/api/internal/handler/routes.go +++ b/api/internal/handler/routes.go @@ -1143,6 +1143,16 @@ func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) { Path: "/inference/images", Handler: inference.ImageInferenceHandler(serverCtx), }, + { + Method: http.MethodGet, + Path: "/inference/modelTypes", + Handler: inference.ModelTypesHandler(serverCtx), + }, + { + Method: http.MethodGet, + Path: "/inference/modelNames", + Handler: inference.ModelNamesByTypeHandler(serverCtx), + }, }, rest.WithPrefix("/pcm/v1"), ) diff --git a/api/internal/logic/cloud/commitgeneraltasklogic.go b/api/internal/logic/cloud/commitgeneraltasklogic.go index ac35e1b8..4fd561ae 100644 --- a/api/internal/logic/cloud/commitgeneraltasklogic.go +++ b/api/internal/logic/cloud/commitgeneraltasklogic.go @@ -94,7 +94,7 @@ func (l *CommitGeneralTaskLogic) CommitGeneralTask(req *types.GeneralTaskReq) er Name: req.Name, CommitTime: time.Now(), YamlString: strings.Join(req.ReqBody, "\n---\n"), - AdapterTypeDict: 0, + AdapterTypeDict: "0", SynergyStatus: synergyStatus, Strategy: strategy, } diff --git a/api/internal/logic/core/commitvmtasklogic.go b/api/internal/logic/core/commitvmtasklogic.go index 2c2424bf..aebf4f48 100644 --- a/api/internal/logic/core/commitvmtasklogic.go +++ b/api/internal/logic/core/commitvmtasklogic.go @@ -86,7 +86,7 @@ func (l *CommitVmTaskLogic) CommitVmTask(req *types.CommitVmTaskReq) (resp *type Name: req.Name, CommitTime: time.Now(), Description: "vm task", - AdapterTypeDict: 0, + AdapterTypeDict: "0", SynergyStatus: synergyStatus, Strategy: strategy, } diff --git a/api/internal/logic/core/pagelisttasklogic.go b/api/internal/logic/core/pagelisttasklogic.go index 96ccf437..52ad4171 100644 --- a/api/internal/logic/core/pagelisttasklogic.go +++ b/api/internal/logic/core/pagelisttasklogic.go @@ -263,7 +263,7 @@ func (l *PageListTaskLogic) updateAiTaskStatus(tasklist []*types.TaskModel, ch c var wg sync.WaitGroup for _, aitask := range aiTaskList { t := aitask - if t.Status == constants.Completed || t.Status == constants.Failed { + if t.Status == constants.Completed || t.Status == constants.Failed || t.JobId == "" { continue } wg.Add(1) diff --git a/api/internal/logic/core/taskdetailslogic.go b/api/internal/logic/core/taskdetailslogic.go index 0a8c7dbd..2352974b 100644 --- a/api/internal/logic/core/taskdetailslogic.go +++ b/api/internal/logic/core/taskdetailslogic.go @@ -36,14 +36,14 @@ func (l *TaskDetailsLogic) TaskDetails(req *types.FId) (resp *types.TaskDetailsR var cList []*types.ClusterInfo var subList []*types.SubTaskInfo switch task.AdapterTypeDict { - case 0: + case "0": l.svcCtx.DbEngin.Table("task_cloud").Where("task_id", task.Id).Scan(&subList) if len(subList) <= 0 { l.svcCtx.DbEngin.Table("task_vm").Where("task_id", task.Id).Find(&subList) } - case 1: + case "1": l.svcCtx.DbEngin.Table("task_ai").Where("task_id", task.Id).Scan(&subList) - case 2: + case "2": l.svcCtx.DbEngin.Table("task_hpc").Where("task_id", task.Id).Scan(&subList) } for _, sub := range subList { diff --git a/api/internal/logic/core/tasklistlogic.go b/api/internal/logic/core/tasklistlogic.go index 382db2e2..fbbdcc91 100644 --- a/api/internal/logic/core/tasklistlogic.go +++ b/api/internal/logic/core/tasklistlogic.go @@ -122,7 +122,7 @@ func (l *TaskListLogic) TaskList(req *types.TaskListReq) (resp *types.TaskListRe func (l *TaskListLogic) updateAitaskStatus(tasks []models.Task, ch chan<- struct{}) { for _, task := range tasks { - if task.AdapterTypeDict != 1 { + if task.AdapterTypeDict != "1" { continue } if task.Status == constants.Succeeded { diff --git a/api/internal/logic/hpc/commithpctasklogic.go b/api/internal/logic/hpc/commithpctasklogic.go index ebd6f252..00f7e3e1 100644 --- a/api/internal/logic/hpc/commithpctasklogic.go +++ b/api/internal/logic/hpc/commithpctasklogic.go @@ -40,7 +40,7 @@ func (l *CommitHpcTaskLogic) CommitHpcTask(req *types.CommitHpcTaskReq) (resp *t Strategy: 0, SynergyStatus: 0, CommitTime: time.Now(), - AdapterTypeDict: 2, + AdapterTypeDict: "2", } // 保存任务数据到数据库 diff --git a/api/internal/logic/inference/imageinferencelogic.go b/api/internal/logic/inference/imageinferencelogic.go index 5230c0be..2a06f78d 100644 --- a/api/internal/logic/inference/imageinferencelogic.go +++ b/api/internal/logic/inference/imageinferencelogic.go @@ -10,10 +10,15 @@ import ( "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/strategy" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types" + "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants" + "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models" "math/rand" "mime/multipart" "net/http" + "sort" + "strconv" "sync" + "time" ) type ImageInferenceLogic struct { @@ -88,7 +93,6 @@ func (l *ImageInferenceLogic) ImageInfer(r *http.Request, req *types.ImageInfere var strat strategy.Strategy switch opt.Strategy { case strategy.STATIC_WEIGHT: - //todo resources should match cluster StaticWeightMap strat = strategy.NewStaticWeightStrategy(opt.StaticWeightMap, int32(len(ts))) if err != nil { return nil, err @@ -128,33 +132,69 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s var wg sync.WaitGroup var cluster_ch = make(chan struct { urls []*collector.ImageInferUrl + clusterId string clusterName string imageNum int32 }, len(clusters)) var cs []struct { urls []*collector.ImageInferUrl + clusterId string clusterName string imageNum int32 } collectorMap := svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId] + //save task + var synergystatus int64 + if len(clusters) > 1 { + synergystatus = 1 + } + + strategyCode, err := svcCtx.Scheduler.AiStorages.GetStrategyCode(opt.Strategy) + if err != nil { + return nil, err + } + adapterName, err := svcCtx.Scheduler.AiStorages.GetAdapterNameById(opt.AdapterId) + if err != nil { + return nil, err + } + id, err := svcCtx.Scheduler.AiStorages.SaveTask(opt.TaskName, strategyCode, synergystatus, "11") + if err != nil { + return nil, err + } + + svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "create", "任务创建中") + + //save taskai + for _, c := range clusters { + clusterName, _ := svcCtx.Scheduler.AiStorages.GetClusterNameById(c.ClusterId) + opt.Replica = c.Replicas + err := svcCtx.Scheduler.AiStorages.SaveAiTask(id, opt, adapterName, c.ClusterId, clusterName, "", constants.Saved, "") + if err != nil { + return nil, err + } + } + for _, cluster := range clusters { wg.Add(1) c := cluster go func() { imageUrls, err := collectorMap[c.ClusterId].GetImageInferUrl(ctx, opt) if err != nil { + wg.Done() return } clusterName, _ := svcCtx.Scheduler.AiStorages.GetClusterNameById(c.ClusterId) s := struct { urls []*collector.ImageInferUrl + clusterId string clusterName string imageNum int32 }{ urls: imageUrls, + clusterId: c.ClusterId, clusterName: clusterName, imageNum: c.Replicas, } @@ -171,11 +211,42 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s cs = append(cs, s) } + var aiTaskList []*models.TaskAi + tx := svcCtx.DbEngin.Raw("select * from task_ai where `task_id` = ? ", id).Scan(&aiTaskList) + if tx.Error != nil { + return nil, tx.Error + + } + //change cluster status + if len(clusters) != len(cs) { + var acs []*strategy.AssignedCluster + for _, cluster := range clusters { + if contains(cs, cluster.ClusterId) { + continue + } else { + var ac *strategy.AssignedCluster + ac = cluster + acs = append(acs, ac) + } + } + + // update failed cluster status + for _, ac := range acs { + for _, t := range aiTaskList { + if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) { + t.Status = constants.Failed + err := svcCtx.Scheduler.AiStorages.UpdateAiTask(t) + if err != nil { + logx.Errorf(tx.Error.Error()) + } + } + } + } + } + var result_ch = make(chan *types.ImageResult, len(ts)) var results []*types.ImageResult - wg.Add(len(ts)) - var imageNumIdx int32 = 0 var imageNumIdxEnd int32 = 0 for _, c := range cs { @@ -189,16 +260,32 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s new_images = new_images[imageNumIdx:imageNumIdxEnd] imageNumIdx = imageNumIdx + c.imageNum + wg.Add(len(new_images)) go sendInferReq(new_images, c, &wg, result_ch) } wg.Wait() - close(result_ch) for s := range result_ch { results = append(results, s) } + sort.Slice(results, func(p, q int) bool { + return results[p].ClusterName < results[q].ClusterName + }) + + // update succeeded cluster status + for _, c := range cs { + for _, t := range aiTaskList { + if c.clusterId == strconv.Itoa(int(t.ClusterId)) { + t.Status = constants.Completed + err := svcCtx.Scheduler.AiStorages.UpdateAiTask(t) + if err != nil { + logx.Errorf(tx.Error.Error()) + } + } + } + } return results, nil } @@ -207,6 +294,7 @@ func sendInferReq(images []struct { file multipart.File }, cluster struct { urls []*collector.ImageInferUrl + clusterId string clusterName string imageNum int32 }, wg *sync.WaitGroup, ch chan<- *types.ImageResult) { @@ -216,6 +304,7 @@ func sendInferReq(images []struct { file multipart.File }, c struct { urls []*collector.ImageInferUrl + clusterId string clusterName string imageNum int32 }) { @@ -223,6 +312,8 @@ func sendInferReq(images []struct { r, err := getInferResult(c.urls[0].Url, t.file, t.imageResult.ImageName) if err != nil { t.imageResult.ImageResult = err.Error() + t.imageResult.ClusterName = c.clusterName + t.imageResult.Card = c.urls[0].Card ch <- t.imageResult wg.Done() return @@ -239,6 +330,8 @@ func sendInferReq(images []struct { r, err := getInferResult(c.urls[idx].Url, t.file, t.imageResult.ImageName) if err != nil { t.imageResult.ImageResult = err.Error() + t.imageResult.ClusterName = c.clusterName + t.imageResult.Card = c.urls[idx].Card ch <- t.imageResult wg.Done() return @@ -257,7 +350,7 @@ func sendInferReq(images []struct { func getInferResult(url string, file multipart.File, fileName string) (string, error) { var res Res - req := GetACHttpRequest() + req := GetRestyRequest(10) _, err := req. SetFileReader("file", fileName, file). SetResult(&res). @@ -269,8 +362,8 @@ func getInferResult(url string, file multipart.File, fileName string) (string, e return res.Result, nil } -func GetACHttpRequest() *resty.Request { - client := resty.New() +func GetRestyRequest(timeoutSeconds int64) *resty.Request { + client := resty.New().SetTimeout(time.Duration(timeoutSeconds) * time.Second) request := client.R() return request } @@ -278,3 +371,17 @@ func GetACHttpRequest() *resty.Request { type Res struct { Result string `json:"result"` } + +func contains(cs []struct { + urls []*collector.ImageInferUrl + clusterId string + clusterName string + imageNum int32 +}, e string) bool { + for _, c := range cs { + if c.clusterId == e { + return true + } + } + return false +} diff --git a/api/internal/logic/inference/modelnamesbytypelogic.go b/api/internal/logic/inference/modelnamesbytypelogic.go new file mode 100644 index 00000000..02e127ae --- /dev/null +++ b/api/internal/logic/inference/modelnamesbytypelogic.go @@ -0,0 +1,36 @@ +package inference + +import ( + "context" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/storeLink" + + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types" + + "github.com/zeromicro/go-zero/core/logx" +) + +type ModelNamesByTypeLogic struct { + logx.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +func NewModelNamesByTypeLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ModelNamesByTypeLogic { + return &ModelNamesByTypeLogic{ + Logger: logx.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *ModelNamesByTypeLogic) ModelNamesByType(req *types.ModelNamesReq) (resp *types.ModelNamesResp, err error) { + resp = &types.ModelNamesResp{} + models, err := storeLink.GetModelNamesByType(req.Type) + if err != nil { + logx.Errorf("ModelNamesByType err: %v", err) + return nil, err + } + resp.ModelNames = models + return resp, nil +} diff --git a/api/internal/logic/inference/modeltypeslogic.go b/api/internal/logic/inference/modeltypeslogic.go new file mode 100644 index 00000000..430aae8f --- /dev/null +++ b/api/internal/logic/inference/modeltypeslogic.go @@ -0,0 +1,32 @@ +package inference + +import ( + "context" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/storeLink" + + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types" + + "github.com/zeromicro/go-zero/core/logx" +) + +type ModelTypesLogic struct { + logx.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +func NewModelTypesLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ModelTypesLogic { + return &ModelTypesLogic{ + Logger: logx.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *ModelTypesLogic) ModelTypes() (resp *types.ModelTypesResp, err error) { + resp = &types.ModelTypesResp{} + mTypes := storeLink.GetModelTypes() + resp.ModelTypes = mTypes + return resp, nil +} diff --git a/api/internal/logic/schedule/schedulesubmitlogic.go b/api/internal/logic/schedule/schedulesubmitlogic.go index 64a1fbc4..fdcd4f2f 100644 --- a/api/internal/logic/schedule/schedulesubmitlogic.go +++ b/api/internal/logic/schedule/schedulesubmitlogic.go @@ -71,7 +71,7 @@ func (l *ScheduleSubmitLogic) ScheduleSubmit(req *types.ScheduleReq) (resp *type if err != nil { return nil, err } - id, err := l.svcCtx.Scheduler.AiStorages.SaveTask(req.AiOption.TaskName, strategyCode, synergystatus) + id, err := l.svcCtx.Scheduler.AiStorages.SaveTask(req.AiOption.TaskName, strategyCode, synergystatus, "10") if err != nil { return nil, err } diff --git a/api/internal/scheduler/database/aiStorage.go b/api/internal/scheduler/database/aiStorage.go index 68e11316..511920b8 100644 --- a/api/internal/scheduler/database/aiStorage.go +++ b/api/internal/scheduler/database/aiStorage.go @@ -94,7 +94,7 @@ func (s *AiStorage) GetAiTasksByAdapterId(adapterId string) ([]*models.TaskAi, e return resp, nil } -func (s *AiStorage) SaveTask(name string, strategyCode int64, synergyStatus int64) (int64, error) { +func (s *AiStorage) SaveTask(name string, strategyCode int64, synergyStatus int64, aiType string) (int64, error) { // 构建主任务结构体 taskModel := models.Task{ Status: constants.Saved, @@ -102,7 +102,8 @@ func (s *AiStorage) SaveTask(name string, strategyCode int64, synergyStatus int6 Name: name, SynergyStatus: synergyStatus, Strategy: strategyCode, - AdapterTypeDict: 1, + AdapterTypeDict: "1", + TaskTypeDict: aiType, CommitTime: time.Now(), } // 保存任务数据到数据库 @@ -113,9 +114,22 @@ func (s *AiStorage) SaveTask(name string, strategyCode int64, synergyStatus int6 return taskModel.Id, nil } -func (s *AiStorage) SaveAiTask(taskId int64, option *option.AiOption, adapterName string, clusterId string, clusterName string, jobId string, status string, msg string) error { +func (s *AiStorage) SaveAiTask(taskId int64, opt option.Option, adapterName string, clusterId string, clusterName string, jobId string, status string, msg string) error { + var aiOpt *option.AiOption + switch (opt).(type) { + case *option.AiOption: + aiOpt = (opt).(*option.AiOption) + case *option.InferOption: + inferOpt := (opt).(*option.InferOption) + aiOpt = &option.AiOption{} + aiOpt.TaskName = inferOpt.TaskName + aiOpt.Replica = inferOpt.Replica + aiOpt.AdapterId = inferOpt.AdapterId + aiOpt.TaskType = inferOpt.ModelType + aiOpt.StrategyName = inferOpt.Strategy + } // 构建主任务结构体 - aId, err := strconv.ParseInt(option.AdapterId, 10, 64) + aId, err := strconv.ParseInt(aiOpt.AdapterId, 10, 64) if err != nil { return err } @@ -130,14 +144,14 @@ func (s *AiStorage) SaveAiTask(taskId int64, option *option.AiOption, adapterNam AdapterName: adapterName, ClusterId: cId, ClusterName: clusterName, - Name: option.TaskName, - Replica: int64(option.Replica), + Name: aiOpt.TaskName, + Replica: int64(aiOpt.Replica), JobId: jobId, - TaskType: option.TaskType, - Strategy: option.StrategyName, + TaskType: aiOpt.TaskType, + Strategy: aiOpt.StrategyName, Status: status, Msg: msg, - Card: option.ComputeCard, + Card: aiOpt.ComputeCard, CommitTime: time.Now(), } // 保存任务数据到数据库 diff --git a/api/internal/scheduler/schedulers/aiScheduler.go b/api/internal/scheduler/schedulers/aiScheduler.go index 201f565d..8be805d6 100644 --- a/api/internal/scheduler/schedulers/aiScheduler.go +++ b/api/internal/scheduler/schedulers/aiScheduler.go @@ -222,7 +222,7 @@ func (as *AiScheduler) AssignTask(clusters []*strategy.AssignedCluster) (interfa synergystatus = 1 } strategyCode, err := as.AiStorages.GetStrategyCode(as.option.StrategyName) - taskId, err := as.AiStorages.SaveTask(as.option.TaskName, strategyCode, synergystatus) + taskId, err := as.AiStorages.SaveTask(as.option.TaskName, strategyCode, synergystatus, "10") if err != nil { return nil, errors.New("database add failed: " + err.Error()) } diff --git a/api/internal/scheduler/schedulers/option/inferOption.go b/api/internal/scheduler/schedulers/option/inferOption.go index b576eb0f..65249955 100644 --- a/api/internal/scheduler/schedulers/option/inferOption.go +++ b/api/internal/scheduler/schedulers/option/inferOption.go @@ -16,3 +16,7 @@ type InferOption struct { Cmd string `json:"cmd,optional"` Replica int32 `json:"replicas,optional"` } + +func (a InferOption) GetOptionType() string { + return AI_INFER +} diff --git a/api/internal/scheduler/schedulers/option/option.go b/api/internal/scheduler/schedulers/option/option.go index 111904b4..dcd188eb 100644 --- a/api/internal/scheduler/schedulers/option/option.go +++ b/api/internal/scheduler/schedulers/option/option.go @@ -1,10 +1,11 @@ package option const ( - AI = "ai" - CLOUD = "cloud" - HPC = "hpc" - VM = "vm" + AI_INFER = "ai_infer" + AI = "ai" + CLOUD = "cloud" + HPC = "hpc" + VM = "vm" ) type Option interface { diff --git a/api/internal/storeLink/modelarts.go b/api/internal/storeLink/modelarts.go index 7ebe0f0f..a2532730 100644 --- a/api/internal/storeLink/modelarts.go +++ b/api/internal/storeLink/modelarts.go @@ -385,10 +385,13 @@ func (m *ModelArtsLink) GetImageInferUrl(ctx context.Context, option *option.Inf Type: option.ModelType, Card: "npu", } - urlResp, _ := m.modelArtsRpc.ImageReasoningUrl(ctx, urlReq) + urlResp, err := m.modelArtsRpc.ImageReasoningUrl(ctx, urlReq) + if err != nil { + return nil, err + } imageUrl := &collector.ImageInferUrl{ Url: urlResp.Url, - Card: option.ComputeCard, + Card: "npu", } imageUrls = append(imageUrls, imageUrl) diff --git a/api/internal/storeLink/octopus.go b/api/internal/storeLink/octopus.go index e84e54fd..ff25a0d3 100644 --- a/api/internal/storeLink/octopus.go +++ b/api/internal/storeLink/octopus.go @@ -872,21 +872,28 @@ func setResourceIdByCard(option *option.AiOption, specs *octopus.GetResourceSpec } func (o *OctopusLink) GetImageInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.ImageInferUrl, error) { + req := &octopus.GetNotebookListReq{ + Platform: o.platform, + PageIndex: o.pageIndex, + PageSize: o.pageSize, + } + list, err := o.octopusRpc.GetNotebookList(ctx, req) + if err != nil { + return nil, err + } + var imageUrls []*collector.ImageInferUrl - imageUrl := &collector.ImageInferUrl{ - Url: "http://0.0.0.0:8888/image", - Card: "mlu", + for _, notebook := range list.Payload.GetNotebooks() { + if strings.Contains(notebook.AlgorithmName, option.ModelName) { + names := strings.Split(notebook.AlgorithmName, UNDERSCORE) + imageUrl := &collector.ImageInferUrl{ + Url: DOMAIN + notebook.Tasks[0].Url + FORWARD_SLASH + "image", + Card: names[2], + } + imageUrls = append(imageUrls, imageUrl) + } else { + continue + } } - imageUrl1 := &collector.ImageInferUrl{ - Url: "http://0.0.0.0:8888/image", - Card: "gcu", - } - imageUrl2 := &collector.ImageInferUrl{ - Url: "http://0.0.0.0:8888/image", - Card: "biv100", - } - imageUrls = append(imageUrls, imageUrl) - imageUrls = append(imageUrls, imageUrl1) - imageUrls = append(imageUrls, imageUrl2) return imageUrls, nil } diff --git a/api/internal/storeLink/shuguangai.go b/api/internal/storeLink/shuguangai.go index 5c5145d9..6b64dcdb 100644 --- a/api/internal/storeLink/shuguangai.go +++ b/api/internal/storeLink/shuguangai.go @@ -739,7 +739,10 @@ func (s *ShuguangAi) GetImageInferUrl(ctx context.Context, option *option.InferO Card: "dcu", } - urlResp, _ := s.aCRpc.GetInferUrl(ctx, urlReq) + urlResp, err := s.aCRpc.GetInferUrl(ctx, urlReq) + if err != nil { + return nil, err + } imageUrl := &collector.ImageInferUrl{ Url: urlResp.Url, Card: option.ComputeCard, diff --git a/api/internal/storeLink/storeLink.go b/api/internal/storeLink/storeLink.go index 976aa54a..d5ead2e6 100644 --- a/api/internal/storeLink/storeLink.go +++ b/api/internal/storeLink/storeLink.go @@ -76,6 +76,9 @@ var ( 3: "制作完成", 4: "制作失败", } + ModelTypeMap = map[string][]string{ + "image_recognition": {"imagenet_resnet50"}, + } AITYPE = map[string]string{ "1": OCTOPUS, "2": MODELARTS, @@ -128,6 +131,22 @@ func GetResourceTypes() []string { return resourceTypes } +func GetModelTypes() []string { + var mTypes []string + for k, _ := range ModelTypeMap { + mTypes = append(mTypes, k) + } + return mTypes +} + +func GetModelNamesByType(t string) ([]string, error) { + _, ok := ModelTypeMap[t] + if !ok { + return nil, errors.New("model type does not exist") + } + return ModelTypeMap[t], nil +} + func GetDatasetsNames(ctx context.Context, collectorMap map[string]collector.AiCollector) ([]string, error) { var wg sync.WaitGroup var errCh = make(chan interface{}, len(collectorMap)) diff --git a/api/internal/types/types.go b/api/internal/types/types.go index b92be857..5aea737a 100644 --- a/api/internal/types/types.go +++ b/api/internal/types/types.go @@ -5879,6 +5879,18 @@ type Category struct { Name string `json:"name"` } +type ModelTypesResp struct { + ModelTypes []string `json:"types"` +} + +type ModelNamesReq struct { + Type string `form:"type"` +} + +type ModelNamesResp struct { + ModelNames []string `json:"models"` +} + type ImageInferenceReq struct { TaskName string `form:"taskName"` TaskDesc string `form:"taskDesc"` diff --git a/pkg/models/taskmodel_gen.go b/pkg/models/taskmodel_gen.go index 5dae890c..6f5baa88 100644 --- a/pkg/models/taskmodel_gen.go +++ b/pkg/models/taskmodel_gen.go @@ -49,7 +49,8 @@ type ( Result string `db:"result"` // 作业结果 DeletedAt gorm.DeletedAt `gorm:"index"` NsID string `db:"ns_id"` - AdapterTypeDict int `db:"adapter_type_dict"` //任务类型(对应字典表的值) + AdapterTypeDict string `db:"adapter_type_dict"` //任务类型(对应字典表的值) + TaskTypeDict string `db:"task_type_dict"` } )