diff --git a/api/internal/logic/schedule/schedulegetaijoblogloglogic.go b/api/internal/logic/schedule/schedulegetaijoblogloglogic.go index da5a0c7a..e0f304de 100644 --- a/api/internal/logic/schedule/schedulegetaijoblogloglogic.go +++ b/api/internal/logic/schedule/schedulegetaijoblogloglogic.go @@ -26,7 +26,11 @@ func NewScheduleGetAiJobLogLogLogic(ctx context.Context, svcCtx *svc.ServiceCont func (l *ScheduleGetAiJobLogLogLogic) ScheduleGetAiJobLogLog(req *types.AiJobLogReq) (resp *types.AiJobLogResp, err error) { resp = &types.AiJobLogResp{} - log, err := l.svcCtx.Scheduler.AiService.AiCollectorAdapterMap[req.AdapterId][req.ClusterId].GetTrainingTaskLog(l.ctx, req.TaskId, req.InstanceNum) + id, err := l.svcCtx.Scheduler.AiStorages.GetAiTaskIdByClusterIdAndTaskId(req.ClusterId, req.TaskId) + if err != nil { + return nil, err + } + log, err := l.svcCtx.Scheduler.AiService.AiCollectorAdapterMap[req.AdapterId][req.ClusterId].GetTrainingTaskLog(l.ctx, id, req.InstanceNum) if err != nil { return nil, err } diff --git a/api/internal/logic/schedule/schedulesubmitlogic.go b/api/internal/logic/schedule/schedulesubmitlogic.go index 2b9956d1..183699f2 100644 --- a/api/internal/logic/schedule/schedulesubmitlogic.go +++ b/api/internal/logic/schedule/schedulesubmitlogic.go @@ -6,6 +6,7 @@ import ( "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/schedulers/option" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types" + "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants" "github.com/zeromicro/go-zero/core/logx" ) @@ -51,6 +52,10 @@ func (l *ScheduleSubmitLogic) ScheduleSubmit(req *types.ScheduleReq) (resp *type switch opt.GetOptionType() { case option.AI: + id, err := l.svcCtx.Scheduler.AiStorages.SaveTask(req.AiOption.TaskName) + if err != nil { + return nil, err + } rs := (results).([]*schedulers.AiResult) for _, r := range rs { scheResult := &types.ScheduleResult{} @@ -59,12 +64,13 @@ func (l *ScheduleSubmitLogic) ScheduleSubmit(req *types.ScheduleReq) (resp *type scheResult.Strategy = r.Strategy scheResult.Replica = r.Replica scheResult.Msg = r.Msg + err := l.svcCtx.Scheduler.AiStorages.SaveAiTask(id, opt, r.ClusterId, r.TaskId, constants.Running, r.Msg) + if err != nil { + return nil, err + } resp.Results = append(resp.Results, scheResult) } - err = l.svcCtx.Scheduler.AiStorages.SaveTask(req.AiOption.TaskName) - if err != nil { - return nil, err - } + } return resp, nil diff --git a/api/internal/scheduler/database/aiStorage.go b/api/internal/scheduler/database/aiStorage.go index 2cf648aa..8efb98b8 100644 --- a/api/internal/scheduler/database/aiStorage.go +++ b/api/internal/scheduler/database/aiStorage.go @@ -2,10 +2,12 @@ package database import ( "github.com/zeromicro/go-zero/core/logx" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/schedulers/option" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types" "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants" "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models" "gorm.io/gorm" + "strconv" "time" ) @@ -48,7 +50,17 @@ func (s *AiStorage) GetAdapterIdsByType(adapterType string) ([]string, error) { return ids, nil } -func (s *AiStorage) SaveTask(name string) error { +func (s *AiStorage) GetAiTasks() ([]*types.AiTaskDb, error) { + var resp []*types.AiTaskDb + tx := s.DbEngin.Raw("select * from task_ai").Scan(&resp) + if tx.Error != nil { + logx.Errorf(tx.Error.Error()) + return nil, tx.Error + } + return resp, nil +} + +func (s *AiStorage) SaveTask(name string) (int64, error) { // 构建主任务结构体 taskModel := models.Task{ Status: constants.Saved, @@ -58,12 +70,52 @@ func (s *AiStorage) SaveTask(name string) error { } // 保存任务数据到数据库 tx := s.DbEngin.Create(&taskModel) + if tx.Error != nil { + return 0, tx.Error + } + return taskModel.Id, nil +} + +func (s *AiStorage) SaveAiTask(taskId int64, option *option.AiOption, clusterId string, jobId string, status string, msg string) error { + // 构建主任务结构体 + aId, err := strconv.ParseInt(option.AdapterId, 10, 64) + if err != nil { + return err + } + cId, err := strconv.ParseInt(clusterId, 10, 64) + if err != nil { + return err + } + aiTaskModel := models.TaskAi{ + TaskId: taskId, + AdapterId: aId, + ClusterId: cId, + Name: option.TaskName, + Replica: option.Replica, + JobId: jobId, + Strategy: option.StrategyName, + Status: status, + Msg: msg, + CommitTime: time.Now(), + } + // 保存任务数据到数据库 + tx := s.DbEngin.Create(&aiTaskModel) if tx.Error != nil { return tx.Error } return nil } +func (s *AiStorage) GetAiTaskIdByClusterIdAndTaskId(clusterId string, taskId string) (string, error) { + var aiTask models.TaskAi + tx := s.DbEngin.Raw("select * from task_ai where `cluster_id` = ? and `task_id` = ?", clusterId, taskId).Scan(&aiTask) + if tx.Error != nil { + logx.Errorf(tx.Error.Error()) + return "", tx.Error + } + return aiTask.JobId, nil +} + func (s *AiStorage) UpdateTask() error { return nil } diff --git a/api/internal/scheduler/schedulers/aiScheduler.go b/api/internal/scheduler/schedulers/aiScheduler.go index a3e3e366..af50d201 100644 --- a/api/internal/scheduler/schedulers/aiScheduler.go +++ b/api/internal/scheduler/schedulers/aiScheduler.go @@ -26,6 +26,7 @@ import ( "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/strategy" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/strategy/param" "gitlink.org.cn/JointCloud/pcm-coordinator/api/pkg/response" + "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants" "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models" "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/utils" "gitlink.org.cn/JointCloud/pcm-octopus/octopus" @@ -168,32 +169,46 @@ func (as *AiScheduler) AssignTask(clusters []*strategy.AssignedCluster) (interfa errs = append(errs, e) } - if len(errs) == len(clusters) { - return nil, errors.New("submit task failed") + for s := range ch { + results = append(results, s) } if len(errs) != 0 { - var msg string + taskId, err := as.AiStorages.SaveTask(as.option.TaskName) + if err != nil { + return nil, err + } + var errmsg string for _, err := range errs { e := (err).(struct { err error clusterId string }) - msg += fmt.Sprintf("clusterId: %v , error: %v \n", e.clusterId, e.err.Error()) + msg := fmt.Sprintf("clusterId: %v , error: %v \n", e.clusterId, e.err.Error()) + errmsg += msg + err := as.AiStorages.SaveAiTask(taskId, as.option, e.clusterId, "", constants.Failed, msg) + if err != nil { + return nil, err + } } for s := range ch { if s.Msg != "" { - msg += fmt.Sprintf("clusterId: %v , error: %v \n", s.ClusterId, s.Msg) + msg := fmt.Sprintf("clusterId: %v , error: %v \n", s.ClusterId, s.Msg) + errmsg += msg + err := as.AiStorages.SaveAiTask(taskId, as.option, s.ClusterId, "", constants.Failed, msg) + if err != nil { + return nil, err + } } else { - msg += fmt.Sprintf("clusterId: %v , submitted successfully, taskId: %v \n", s.ClusterId, s.TaskId) + msg := fmt.Sprintf("clusterId: %v , submitted successfully, taskId: %v \n", s.ClusterId, s.TaskId) + errmsg += msg + err := as.AiStorages.SaveAiTask(taskId, as.option, s.ClusterId, s.TaskId, constants.Succeeded, msg) + if err != nil { + return nil, err + } } } - return nil, errors.New(msg) - } - - for s := range ch { - // TODO: database operation - results = append(results, s) + return nil, errors.New(errmsg) } return results, nil diff --git a/api/internal/scheduler/schedulers/option/aiOption.go b/api/internal/scheduler/schedulers/option/aiOption.go index f8a6495f..d2f8d3eb 100644 --- a/api/internal/scheduler/schedulers/option/aiOption.go +++ b/api/internal/scheduler/schedulers/option/aiOption.go @@ -4,6 +4,7 @@ type AiOption struct { AdapterId string ClusterIds []string TaskName string + Replica int64 ResourceType string // cpu/gpu/compute card CpuCoreNum int64 TaskType string // pytorch/tensorflow/mindspore