diff --git a/api/internal/logic/core/pagelisttasklogic.go b/api/internal/logic/core/pagelisttasklogic.go index ec1b8945..8877a03a 100644 --- a/api/internal/logic/core/pagelisttasklogic.go +++ b/api/internal/logic/core/pagelisttasklogic.go @@ -165,53 +165,65 @@ func (l *PageListTaskLogic) updateTaskStatus(tasks []*types.TaskModel, ch chan<- } func (l *PageListTaskLogic) updateAiTaskStatus(tasks []*types.TaskModel, ch chan<- struct{}) { + for i := len(tasks) - 1; i >= 0; i-- { + if tasks[i].AdapterTypeDict == 0 || tasks[i].Status == constants.Succeeded || tasks[i].Status == constants.Failed { + tasks = append(tasks[:i], tasks[i+1:]...) + } + } + + if len(tasks) == 0 { + ch <- struct{}{} + return + } + + task := tasks[0] + for i, _ := range tasks { + earliest, _ := time.Parse(constants.Layout, task.UpdatedTime) + latest, _ := time.Parse(constants.Layout, tasks[i].UpdatedTime) + if earliest.Before(latest) { + task = tasks[i] + } + } + + var aiTaskList []*models.TaskAi + tx := l.svcCtx.DbEngin.Raw("select * from task_ai where `task_id` = ? ", task.Id).Scan(&aiTaskList) + if tx.Error != nil { + logx.Errorf(tx.Error.Error()) + return + } + + if len(aiTaskList) == 0 { + ch <- struct{}{} + return + } + var wg sync.WaitGroup - for _, task := range tasks { - if task.AdapterTypeDict != 1 { + for _, aitask := range aiTaskList { + t := aitask + if t.Status == constants.Completed { continue } - if task.Status == constants.Succeeded || task.Status == constants.Failed { - continue - } - - var aiTaskList []*models.TaskAi - tx := l.svcCtx.DbEngin.Raw("select * from task_ai where `task_id` = ? ", task.Id).Scan(&aiTaskList) - if tx.Error != nil { - logx.Errorf(tx.Error.Error()) - return - } - - if len(aiTaskList) == 0 { - continue - } - - for _, aitask := range aiTaskList { - t := aitask - if t.Status == constants.Completed { - continue - } - wg.Add(1) - go func() { - trainingTask, err := l.svcCtx.Scheduler.AiService.AiCollectorAdapterMap[strconv.FormatInt(t.AdapterId, 10)][strconv.FormatInt(t.ClusterId, 10)].GetTrainingTask(l.ctx, t.JobId) - if err != nil { - msg := fmt.Sprintf("AiTaskId: %v, clusterId: %v , JobId: %v, error: %v \n", t.Id, t.ClusterId, t.JobId, err.Error()) - logx.Errorf(errors.New(msg).Error()) - wg.Done() - return - } - t.Status = trainingTask.Status - t.StartTime = trainingTask.Start - t.EndTime = trainingTask.End - err = l.svcCtx.Scheduler.AiStorages.UpdateAiTask(t) - if err != nil { - msg := fmt.Sprintf("AiTaskId: %v, clusterId: %v , JobId: %v, error: %v \n", t.Id, t.ClusterId, t.JobId, err.Error()) - logx.Errorf(errors.New(msg).Error()) - wg.Done() - return - } + wg.Add(1) + go func() { + trainingTask, err := l.svcCtx.Scheduler.AiService.AiCollectorAdapterMap[strconv.FormatInt(t.AdapterId, 10)][strconv.FormatInt(t.ClusterId, 10)].GetTrainingTask(l.ctx, t.JobId) + if err != nil { + msg := fmt.Sprintf("AiTaskId: %v, clusterId: %v , JobId: %v, error: %v \n", t.Id, t.ClusterId, t.JobId, err.Error()) + logx.Errorf(errors.New(msg).Error()) wg.Done() - }() - } + return + } + t.Status = trainingTask.Status + t.StartTime = trainingTask.Start + t.EndTime = trainingTask.End + err = l.svcCtx.Scheduler.AiStorages.UpdateAiTask(t) + if err != nil { + msg := fmt.Sprintf("AiTaskId: %v, clusterId: %v , JobId: %v, error: %v \n", t.Id, t.ClusterId, t.JobId, err.Error()) + logx.Errorf(errors.New(msg).Error()) + wg.Done() + return + } + wg.Done() + }() } wg.Wait() ch <- struct{}{} diff --git a/api/internal/logic/schedule/schedulesubmitlogic.go b/api/internal/logic/schedule/schedulesubmitlogic.go index 98f5f916..5ce7a6ee 100644 --- a/api/internal/logic/schedule/schedulesubmitlogic.go +++ b/api/internal/logic/schedule/schedulesubmitlogic.go @@ -31,6 +31,7 @@ func (l *ScheduleSubmitLogic) ScheduleSubmit(req *types.ScheduleReq) (resp *type resp = &types.ScheduleResp{} opt := &option.AiOption{ AdapterId: req.AiOption.AdapterId, + ClusterIds: req.AiOption.AiClusterIds, TaskName: req.AiOption.TaskName, ResourceType: req.AiOption.ResourceType, Replica: req.AiOption.Replica, diff --git a/api/internal/scheduler/common/common.go b/api/internal/scheduler/common/common.go index ce2ee5e7..68bfaa32 100644 --- a/api/internal/scheduler/common/common.go +++ b/api/internal/scheduler/common/common.go @@ -88,3 +88,12 @@ func RoundFloat(val float64, precision uint) float64 { ratio := math.Pow(10, float64(precision)) return math.Round(val*ratio) / ratio } + +func Contains(s []string, e string) bool { + for _, a := range s { + if a == e { + return true + } + } + return false +} diff --git a/api/internal/scheduler/schedulers/aiScheduler.go b/api/internal/scheduler/schedulers/aiScheduler.go index 39b7e427..f3bd56e1 100644 --- a/api/internal/scheduler/schedulers/aiScheduler.go +++ b/api/internal/scheduler/schedulers/aiScheduler.go @@ -22,6 +22,7 @@ import ( "github.com/zeromicro/go-zero/core/logx" "gitlink.org.cn/JointCloud/pcm-ac/hpcAC" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler" + "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/common" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/schedulers/option" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/collector" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/strategy" @@ -67,6 +68,24 @@ func (as *AiScheduler) GetNewStructForDb(task *response.TaskInfo, resource strin } func (as *AiScheduler) PickOptimalStrategy() (strategy.Strategy, error) { + if as.option.ComputeCard != "" { + m, ok := as.AiService.AiCollectorAdapterMap[as.option.AdapterId] + if ok { + for _, id := range as.option.ClusterIds { + cm, ok := m[id] + if ok { + cards, err := cm.GetComputeCards(as.ctx) + if err != nil { + return nil, err + } + if common.Contains(cards, as.option.ComputeCard) { + return &strategy.SingleAssignment{Cluster: &strategy.AssignedCluster{ClusterId: id, Replicas: 1}}, nil + } + } + } + } + } + if len(as.option.ClusterIds) == 1 { return &strategy.SingleAssignment{Cluster: &strategy.AssignedCluster{ClusterId: as.option.ClusterIds[0], Replicas: 1}}, nil } diff --git a/api/internal/storeLink/octopus.go b/api/internal/storeLink/octopus.go index 97f949d6..4457ccd2 100644 --- a/api/internal/storeLink/octopus.go +++ b/api/internal/storeLink/octopus.go @@ -746,7 +746,7 @@ func (o *OctopusLink) generateCmd(option *option.AiOption) error { case GCU: option.Cmd = "cd /code; python3 train.py" case MLU: - option.Cmd = "su root; cd /torch/venv3/pytorch/bin; source activate; cd /code; python train.py" + option.Cmd = ". /torch/venv3/pytorch/bin/activate; cd /code; python train.py" default: option.Cmd = TRAIN_CMD }