diff --git a/api/internal/logic/storelink/submitlinktasklogic.go b/api/internal/logic/storelink/submitlinktasklogic.go index 98f77f85..80dfab96 100644 --- a/api/internal/logic/storelink/submitlinktasklogic.go +++ b/api/internal/logic/storelink/submitlinktasklogic.go @@ -67,7 +67,7 @@ func (l *SubmitLinkTaskLogic) SubmitLinkTask(req *types.SubmitLinkTaskReq) (resp envs = append(envs, env) } } - task, err := storelink.ILinkage.SubmitTask(req.ImageId, req.Cmd, envs, params, req.ResourceId) + task, err := storelink.ILinkage.SubmitTask(req.ImageId, req.Cmd, envs, params, req.ResourceId, "") if err != nil { return nil, err } diff --git a/api/internal/scheduler/schedulers/aiScheduler.go b/api/internal/scheduler/schedulers/aiScheduler.go index 4310bd46..8b11faf5 100644 --- a/api/internal/scheduler/schedulers/aiScheduler.go +++ b/api/internal/scheduler/schedulers/aiScheduler.go @@ -30,7 +30,7 @@ type AiScheduler struct { yamlString string task *response.TaskInfo *scheduler.Scheduler - option option.AiOption + option *option.AiOption } func NewAiScheduler(val string, scheduler *scheduler.Scheduler) (*AiScheduler, error) { @@ -74,7 +74,7 @@ func (as *AiScheduler) AssignTask(clusters []*strategy.AssignedCluster) error { executorMap := *as.AiExecutor for _, cluster := range clusters { - _, err := executorMap[cluster.Name].Execute(option.AiOption{}) + _, err := executorMap[cluster.Name].Execute(as.option) if err != nil { // TODO: database operation } diff --git a/api/internal/scheduler/schedulers/option/aiOption.go b/api/internal/scheduler/schedulers/option/aiOption.go index 2d45383c..b1029d37 100644 --- a/api/internal/scheduler/schedulers/option/aiOption.go +++ b/api/internal/scheduler/schedulers/option/aiOption.go @@ -1,17 +1,20 @@ package option type AiOption struct { - aiType string // shuguangAi/octopus - resourceType string // cpu/gpu/compute card - taskType string // pytorch/tensorflow + AiType string // shuguangAi/octopus + ResourceType string // cpu/gpu/compute card + TaskType string // pytorch/tensorflow - imageId string - specId string - datasetsId string - codeId string + ImageId string + SpecId string + DatasetsId string + CodeId string + ResourceId string - cmd string + Cmd string + Envs []string + Params []string - datasets string - code string + Datasets string + Code string } diff --git a/api/internal/scheduler/service/executor/aiExecutor.go b/api/internal/scheduler/service/executor/aiExecutor.go index a52ab062..abe91b0c 100644 --- a/api/internal/scheduler/service/executor/aiExecutor.go +++ b/api/internal/scheduler/service/executor/aiExecutor.go @@ -6,6 +6,6 @@ import ( ) type AiExecutor interface { - Execute(option option.AiOption) (interface{}, error) + Execute(option *option.AiOption) (interface{}, error) storeLink.Linkage } diff --git a/api/internal/storeLink/modelarts.go b/api/internal/storeLink/modelarts.go index 31489205..14a8a181 100644 --- a/api/internal/storeLink/modelarts.go +++ b/api/internal/storeLink/modelarts.go @@ -63,7 +63,7 @@ func (o *ModelArtsLink) QueryImageList() (interface{}, error) { return resp, nil } -func (o *ModelArtsLink) SubmitTask(imageId string, cmd string, envs []string, params []string, resourceId string) (interface{}, error) { +func (o *ModelArtsLink) SubmitTask(imageId string, cmd string, envs []string, params []string, resourceId string, aiType string) (interface{}, error) { // modelArts提交任务 environments := make(map[string]string) parameters := make([]*modelarts.ParametersTrainJob, 0) @@ -153,6 +153,10 @@ func (o *ModelArtsLink) GetResourceSpecs() (*collector.ResourceSpecs, error) { return nil, nil } -func (o *ModelArtsLink) Execute(option option.AiOption) (interface{}, error) { - return nil, nil +func (o *ModelArtsLink) Execute(option *option.AiOption) (interface{}, error) { + task, err := o.SubmitTask(option.ImageId, option.Cmd, option.Envs, option.Params, option.ResourceId, option.AiType) + if err != nil { + return nil, err + } + return task, nil } diff --git a/api/internal/storeLink/octopus.go b/api/internal/storeLink/octopus.go index b40da2ee..cdc97ea9 100644 --- a/api/internal/storeLink/octopus.go +++ b/api/internal/storeLink/octopus.go @@ -107,7 +107,7 @@ func (o *OctopusLink) QueryImageList() (interface{}, error) { return resp, nil } -func (o *OctopusLink) SubmitTask(imageId string, cmd string, envs []string, params []string, resourceId string) (interface{}, error) { +func (o *OctopusLink) SubmitTask(imageId string, cmd string, envs []string, params []string, resourceId string, aiType string) (interface{}, error) { // octopus提交任务 // python参数 @@ -200,6 +200,10 @@ func (o *OctopusLink) GetResourceSpecs() (*collector.ResourceSpecs, error) { return nil, nil } -func (o *OctopusLink) Execute(option option.AiOption) (interface{}, error) { - return nil, nil +func (o *OctopusLink) Execute(option *option.AiOption) (interface{}, error) { + task, err := o.SubmitTask(option.ImageId, option.Cmd, option.Envs, option.Params, option.ResourceId, option.AiType) + if err != nil { + return nil, err + } + return task, nil } diff --git a/api/internal/storeLink/shuguangHpc.go b/api/internal/storeLink/shuguangHpc.go index 7c80b456..f7f0af82 100644 --- a/api/internal/storeLink/shuguangHpc.go +++ b/api/internal/storeLink/shuguangHpc.go @@ -144,7 +144,7 @@ func (s ShuguangHpc) QueryImageList() (interface{}, error) { return nil, nil } -func (s ShuguangHpc) SubmitTask(imageId string, cmd string, envs []string, params []string, resourceId string) (interface{}, error) { +func (s ShuguangHpc) SubmitTask(imageId string, cmd string, envs []string, params []string, resourceId string, aiType string) (interface{}, error) { // shuguangHpc提交任务 //判断是否resourceId匹配自定义资源Id diff --git a/api/internal/storeLink/shuguangai.go b/api/internal/storeLink/shuguangai.go index 3d027d41..57fecfc6 100644 --- a/api/internal/storeLink/shuguangai.go +++ b/api/internal/storeLink/shuguangai.go @@ -76,9 +76,7 @@ func (s *ShuguangAi) QueryImageList() (interface{}, error) { return resp, nil } -func (s *ShuguangAi) SubmitTask(imageId string, cmd string, envs []string, params []string, resourceId string) (interface{}, error) { - // shuguangAi提交任务 - +func (s *ShuguangAi) SubmitPytorchTask(imageId string, cmd string, envs []string, params []string, resourceId string) (interface{}, error) { //判断是否resourceId匹配自定义资源Id if resourceId != SHUGUANGAI_CUSTOM_RESOURCE_ID { return nil, errors.New("shuguangAi资源Id不存在") @@ -133,6 +131,18 @@ func (s *ShuguangAi) SubmitTask(imageId string, cmd string, envs []string, param return resp, nil } +func (s *ShuguangAi) SubmitTask(imageId string, cmd string, envs []string, params []string, resourceId string, aiType string) (interface{}, error) { + // shuguangAi提交任务 + if aiType == PYTORCH { + task, err := s.SubmitPytorchTask(imageId, cmd, envs, params, resourceId) + if err != nil { + return nil, err + } + return task, nil + } + return nil, errors.New("shuguangAi不支持的任务类型") +} + func (s *ShuguangAi) QueryTask(taskId string) (interface{}, error) { // shuguangAi获取任务 req := &hpcAC.GetPytorchTaskReq{ @@ -199,6 +209,10 @@ func (o *ShuguangAi) GetResourceSpecs() (*collector.ResourceSpecs, error) { return nil, nil } -func (o *ShuguangAi) Execute(option option.AiOption) (interface{}, error) { - return nil, nil +func (o *ShuguangAi) Execute(option *option.AiOption) (interface{}, error) { + task, err := o.SubmitTask(option.ImageId, option.Cmd, option.Envs, option.Params, option.ResourceId, option.AiType) + if err != nil { + return nil, err + } + return task, nil } diff --git a/api/internal/storeLink/storeLink.go b/api/internal/storeLink/storeLink.go index 18046da4..3a644d8e 100644 --- a/api/internal/storeLink/storeLink.go +++ b/api/internal/storeLink/storeLink.go @@ -31,7 +31,7 @@ type Linkage interface { UploadImage(path string) (interface{}, error) DeleteImage(imageId string) (interface{}, error) QueryImageList() (interface{}, error) - SubmitTask(imageId string, cmd string, envs []string, params []string, resourceId string) (interface{}, error) + SubmitTask(imageId string, cmd string, envs []string, params []string, resourceId string, aiType string) (interface{}, error) QueryTask(taskId string) (interface{}, error) QuerySpecs() (interface{}, error) DeleteTask(taskId string) (interface{}, error)